]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add llama_sampler_init for safe usage of llama_sampler_free (#11727)
authorChristian Fillion <redacted>
Fri, 7 Feb 2025 09:33:27 +0000 (04:33 -0500)
committerGitHub <redacted>
Fri, 7 Feb 2025 09:33:27 +0000 (11:33 +0200)
The C API in llama.h claims users can implement `llama_sampler_i` to
create custom `llama_sampler`. The sampler chain takes ownership and
calls `llama_sampler_free` on them. However, `llama_sampler_free` is
hard-coded to use `delete`. This is undefined behavior if the object
wasn't also allocated via `new` from libllama's C++ runtime. Callers
in C and C-compatible languages do not use C++'s `new` operator. C++
callers may not be sharing the same heap as libllama.

common/llguidance.cpp
include/llama.h
src/llama-sampling.cpp

index 7aa8ddd80297b7613748d042b0d2dd8ebd6c9637..2feeb93c87e3018b66afd0e5e503f71d24be8355 100644 (file)
@@ -254,10 +254,10 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g
         };
     }
 
-    return new llama_sampler{
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_llg_i,
-        /* .ctx   = */ ctx,
-    };
+        /* .ctx   = */ ctx
+    );
 }
 
 #else
index 61907ed404dbfd58fd8136b5df0541e5f14402ec..3784f7d3950e502cb10a5d8a4bb5a3e83687f366 100644 (file)
@@ -1114,11 +1114,12 @@ extern "C" {
     };
 
     struct llama_sampler {
-        struct llama_sampler_i  * iface;
-        llama_sampler_context_t   ctx;
+        const struct llama_sampler_i * iface;
+        llama_sampler_context_t        ctx;
     };
 
     // mirror of llama_sampler_i:
+    LLAMA_API struct llama_sampler * llama_sampler_init  (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
     LLAMA_API const char *           llama_sampler_name  (const struct llama_sampler * smpl);
     LLAMA_API void                   llama_sampler_accept(      struct llama_sampler * smpl, llama_token token);
     LLAMA_API void                   llama_sampler_apply (      struct llama_sampler * smpl, llama_token_data_array * cur_p);
index 26974f5396565b2b9f5e6057404972107ffab9bf..990b6129746dee801f7742935cf82c104d476cf7 100644 (file)
@@ -316,6 +316,13 @@ static uint32_t get_rng_seed(uint32_t seed) {
 
 // llama_sampler API
 
+struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
+    return new llama_sampler {
+        /* .iface = */ iface,
+        /* .ctx   = */ ctx,
+    };
+}
+
 const char * llama_sampler_name(const struct llama_sampler * smpl) {
     if (!smpl->iface) {
         return "(null)";
@@ -347,10 +354,10 @@ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
     }
 
     if (smpl->ctx == nullptr) {
-        return new llama_sampler {
+        return llama_sampler_init(
             /* .iface = */ smpl->iface,
-            /* .ctx   = */ nullptr,
-        };
+            /* .ctx   = */ nullptr
+        );
     }
 
     GGML_ABORT("the sampler does not support cloning");
@@ -472,15 +479,15 @@ static struct llama_sampler_i llama_sampler_chain_i = {
 };
 
 struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_chain_i,
         /* .ctx   = */ new llama_sampler_chain {
             /* .params      = */ params,
             /* .samplers    = */ {},
             /* .t_sample_us = */ 0,
             /* .n_sample    = */ 0,
-        },
-    };
+        }
+    );
 }
 
 void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
@@ -546,10 +553,10 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
 };
 
 struct llama_sampler * llama_sampler_init_greedy() {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_greedy_i,
-        /* .ctx   = */ nullptr,
-    };
+        /* .ctx   = */ nullptr
+    );
 }
 
 // dist
@@ -608,14 +615,14 @@ static struct llama_sampler_i llama_sampler_dist_i = {
 
 struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_dist_i,
         /* .ctx   = */ new llama_sampler_dist {
             /* .seed     = */ seed,
             /* .seed_cur = */ seed_cur,
             /* .rng      = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // softmax
@@ -638,10 +645,10 @@ static struct llama_sampler_i llama_sampler_softmax_i = {
 };
 
 struct llama_sampler * llama_sampler_init_softmax() {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_softmax_i,
-        /* .ctx   = */ nullptr,
-    };
+        /* .ctx   = */ nullptr
+    );
 }
 
 // top-k
@@ -678,12 +685,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
 };
 
 struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_k_i,
         /* .ctx   = */ new llama_sampler_top_k {
             /* .k = */ k,
-        },
-    };
+        }
+    );
 }
 
 // top-p
@@ -744,13 +751,13 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
 };
 
 struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_p_i,
         /* .ctx   = */ new llama_sampler_top_p {
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
-        },
-    };
+        }
+    );
 }
 
 // min-p
@@ -840,13 +847,13 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
 };
 
 struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_min_p_i,
         /* .ctx   = */ new llama_sampler_min_p {
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
-        },
-    };
+        }
+    );
 }
 
 // typical
@@ -939,13 +946,13 @@ static struct llama_sampler_i llama_sampler_typical_i = {
 };
 
 struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_typical_i,
         /* .ctx   = */ new llama_sampler_typical {
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
-        },
-    };
+        }
+    );
 }
 
 // temp
@@ -983,12 +990,12 @@ static struct llama_sampler_i llama_sampler_temp_i = {
 };
 
 struct llama_sampler * llama_sampler_init_temp(float temp) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_i,
         /* .ctx   = */ new llama_sampler_temp {
             /*.temp = */ temp,
-        },
-    };
+        }
+    );
 }
 
 // temp-ext
@@ -1093,14 +1100,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
 };
 
 struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_ext_i,
         /* .ctx   = */ new llama_sampler_temp_ext {
             /* .temp     = */ temp,
             /* .delta    = */ delta,
             /* .exponent = */ exponent,
-        },
-    };
+        }
+    );
 }
 
 // xtc
@@ -1185,7 +1192,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
 
 struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_xtc_i,
         /* .ctx   = */ new llama_sampler_xtc {
             /* .probability   = */ p,
@@ -1194,8 +1201,8 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
             /* .seed          = */ seed,
             /* .seed_cur      = */ seed_cur,
             /* .rng           = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // mirostat
@@ -1292,7 +1299,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
 
 struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_mirostat_i,
         /* .ctx   = */ new llama_sampler_mirostat {
             /* .n_vocab  = */ n_vocab,
@@ -1303,8 +1310,8 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
             /* .m        = */ m,
             /* .mu       = */ 2.0f*tau,
             /* .rng      = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // mirostat v2
@@ -1391,7 +1398,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
 
 struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_mirostat_v2_i,
         /* .ctx   = */ new llama_sampler_mirostat_v2 {
             /* .seed     = */ seed,
@@ -1400,8 +1407,8 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau,
             /* .eta      = */ eta,
             /* .mu       = */ 2.0f*tau,
             /* .rng      = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // grammar
@@ -1528,10 +1535,10 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
         };
     }
 
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_grammar_i,
-        /* .ctx   = */ ctx,
-    };
+        /* .ctx   = */ ctx
+    );
 }
 
 struct llama_sampler * llama_sampler_init_grammar(
@@ -1678,7 +1685,7 @@ struct llama_sampler * llama_sampler_init_penalties(
         float penalty_present) {
     penalty_last_n = std::max(penalty_last_n, 0);
 
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx   = */ new llama_sampler_penalties {
             /* .penalty_last_n  = */ penalty_last_n,
@@ -1687,8 +1694,8 @@ struct llama_sampler * llama_sampler_init_penalties(
             /* .penalty_present = */ penalty_present,
             /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
             /* .token_count     = */ {},
-        },
-    };
+        }
+    );
 }
 
 // DRY
@@ -2041,7 +2048,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
         }
     }
 
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_dry_i,
         /* .ctx   = */ new llama_sampler_dry {
             /* .total_context_size     = */ context_size,
@@ -2053,8 +2060,8 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
             /* .dry_repeat_count       = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
             /* .dry_max_token_repeat   = */ {},
             /* .last_tokens            = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
-        },
-    };
+        }
+    );
 }
 
 // wrapper for test-sampling.cpp
@@ -2155,14 +2162,14 @@ struct llama_sampler * llama_sampler_init_logit_bias(
                          int32_t   n_vocab,
                          int32_t   n_logit_bias,
           const llama_logit_bias * logit_bias) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_logit_bias_i,
         /* .ctx   = */ new llama_sampler_logit_bias {
             /* .n_vocab    = */ n_vocab,
             /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
             /* .to_search  = */ {},
-        },
-    };
+        }
+    );
 }
 
 // infill
@@ -2377,14 +2384,14 @@ static struct llama_sampler_i llama_sampler_infill_i = {
 };
 
 struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_infill_i,
         /* .ctx   = */ new llama_sampler_infill {
             /* .vocab = */ vocab,
             /* .buf0  = */ std::vector<char>(512),
             /* .buf1  = */ std::vector<char>(512),
-        },
-    };
+        }
+    );
 }
 
 // utils