]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : minor sampling refactor (2) (#9386)
authorslaren <redacted>
Mon, 9 Sep 2024 15:10:46 +0000 (17:10 +0200)
committerGitHub <redacted>
Mon, 9 Sep 2024 15:10:46 +0000 (17:10 +0200)
12 files changed:
examples/batched.swift/Sources/main.swift
examples/batched/batched.cpp
examples/gritlm/gritlm.cpp
examples/llama.android/llama/src/main/cpp/llama-android.cpp
examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
examples/passkey/passkey.cpp
examples/save-load-state/save-load-state.cpp
examples/server/server.cpp
examples/simple/simple.cpp
include/llama.h
src/llama-sampling.cpp
tests/test-sampling.cpp

index 4bc2bbf2c1570845ba18db0db3c7dbf6ca46497f..9f7c49492dda188cfbc18a6b465769c030d3781b 100644 (file)
@@ -140,8 +140,6 @@ while n_cur <= n_len {
 
         let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
 
-        llama_sampler_accept(smpl, new_token_id)
-
         // is it an end of stream? -> mark the stream as finished
         if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
             i_batch[i] = -1
index f5f309022c8e68ae51816038176a95f6bbe76862..615d6f0f50ef0929a64d454a820de3eda3628be4 100644 (file)
@@ -172,8 +172,6 @@ int main(int argc, char ** argv) {
 
             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
 
-            llama_sampler_accept(smpl, new_token_id);
-
             // is it an end of generation? -> mark the stream as finished
             if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
                 i_batch[i] = -1;
index e1efbf57394adca156b87c31aaa4368811ed678c..6f060e2dcec62d63aa108a5203e286481ff00afb 100644 (file)
@@ -121,7 +121,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
         llama_decode(ctx, bat);
 
         llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
-        llama_sampler_accept(smpl, token);
 
         if (token == eos_token) {
             break;
index 06ec160c2994042454eb0552e83c8767b13fe8f0..f611809c6deff72aa4d3686e6638b536b0371b23 100644 (file)
@@ -414,8 +414,6 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
     // sample the most likely token
     const auto new_token_id = llama_sampler_sample(sampler, context, -1);
 
-    llama_sampler_accept(sampler, new_token_id);
-
     const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
     if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
         return nullptr;
index 92f61fe83081d6bcc6f9eaabb1fc4d1d5291989d..dcd9803a2adc27609191f30cfacfa4fbfcac93c5 100644 (file)
@@ -152,8 +152,6 @@ actor LlamaContext {
 
         new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
 
-        llama_sampler_accept(sampling, new_token_id)
-
         if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
             print("\n")
             is_done = true
index 76d235c2c35cc7c57bfd8eee0b47b6155f60bf40..271ef3a98ccf5d5228ded2d133773d182967b129 100644 (file)
@@ -220,8 +220,6 @@ int main(int argc, char ** argv) {
         {
             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
 
-            llama_sampler_accept(smpl, new_token_id);
-
             // is it an end of generation?
             if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
                 LOG_TEE("\n");
index b54ec3bd808325c039c29785303fed23ef7601dd..e17ab0ed0b2ea8c9d3fe3d3128772e2dfc84c05c 100644 (file)
@@ -74,8 +74,6 @@ int main(int argc, char ** argv) {
         auto next_token     = llama_sampler_sample(smpl, ctx, -1);
         auto next_token_str = llama_token_to_piece(ctx, next_token);
 
-        llama_sampler_accept(smpl, next_token);
-
         printf("%s", next_token_str.c_str());
         result0 += next_token_str;
 
@@ -132,8 +130,6 @@ int main(int argc, char ** argv) {
         auto next_token     = llama_sampler_sample(smpl2, ctx2, -1);
         auto next_token_str = llama_token_to_piece(ctx2, next_token);
 
-        llama_sampler_accept(smpl2, next_token);
-
         printf("%s", next_token_str.c_str());
         result1 += next_token_str;
 
@@ -222,8 +218,6 @@ int main(int argc, char ** argv) {
         auto next_token     = llama_sampler_sample(smpl3, ctx3, -1);
         auto next_token_str = llama_token_to_piece(ctx3, next_token);
 
-        llama_sampler_accept(smpl3, next_token);
-
         printf("%s", next_token_str.c_str());
         result2 += next_token_str;
 
index 9ab8f8ca61b288a5a8de78586d59ab249a1833d1..de3ea313cfb1159ad18adbba4e6fab58038a8b31 100644 (file)
@@ -613,7 +613,7 @@ struct server_context {
 
     gpt_params params;
 
-    llama_batch batch;
+    llama_batch batch = {};
 
     bool clean_kv_cache = true;
     bool add_bos_token  = true;
index a53cef54771901c24c830be9c086097ee0c407e6..d040172a5beba2cabe97cba735b5eb54d32e0bb2 100644 (file)
@@ -118,8 +118,6 @@ int main(int argc, char ** argv) {
         {
             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
 
-            llama_sampler_accept(smpl, new_token_id);
-
             // is it an end of generation?
             if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
                 LOG_TEE("\n");
index 6334fc30d413c424838f6f4e80a545745684282e..93b3e6e85c485de7eccc9f5c707d5c5c0e3a6a2c 100644 (file)
@@ -1127,15 +1127,16 @@ extern "C" {
                              int32_t   n_logit_bias,
               const llama_logit_bias * logit_bias);
 
-    // Shorthand for:
+    /// @details Sample and accept a token from the idx-th output of the last evaluation
     //
+    // Shorthand for:
     //    const auto * logits = llama_get_logits_ith(ctx, idx);
     //    llama_token_data_array cur_p = { ... init from logits ... };
     //    llama_sampler_apply(smpl, &cur_p);
-    //    return cur_p.data[cur_p.selected].id;
-    //
-    // At this point, this is mostly a convenience function.
-    //
+    //    auto token = cur_p.data[cur_p.selected].id;
+    //    llama_sampler_accept(smpl, token);
+    //    return token;
+    // Returns the sampled token
     LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
 
     // TODO: extend in the future
index 41f48ec2867797e64ab983d6855d3ea2cc0741db..6f448b80c44c1979809663ab9a3033be0bcf5935 100644 (file)
@@ -8,49 +8,44 @@
 #include <cstring>
 #include <ctime>
 #include <cfloat>
+#include <cmath>
 #include <numeric>
 #include <random>
 #include <unordered_map>
 
-static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float> & probs) {
-#if 1
-    probs.resize(cur_p->size);
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        probs[i] = cur_p->data[i].p;
-    }
-
-    std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
-#else
-    // avoid the copy with a custom iterator
+static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
+    // iterator for the probabilities
+#ifdef __GNUC__
     #pragma GCC diagnostic push
     #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
+#endif
 
     struct probs_iterator {
         typedef std::input_iterator_tag iterator_category;
         typedef float value_type;
         typedef float * pointer;
         typedef float & reference;
-        typedef size_t difference_type;
+        typedef ptrdiff_t difference_type;
 
-        const llama_token_data_array * data;
-        size_t i;
+        const llama_token_data * data;
 
-        bool operator==(const probs_iterator & other) const { return data + i == other.data + other.i; }
-        bool operator!=(const probs_iterator & other) const { return data + i != other.data + other.i; }
-        float operator*() const { return data->data[i].p; }
-        probs_iterator & operator++() { ++i; return *this; }
-        probs_iterator operator++(int) { probs_iterator tmp = *this; ++i; return tmp; }
+        bool operator==(const probs_iterator & other) const { return data == other.data; }
+        bool operator!=(const probs_iterator & other) const { return data != other.data; }
+        const float & operator*() const { return data->p; }
+        probs_iterator & operator++() { ++data; return *this; }
+        probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
     };
-    #pragma GCC diagnostic pop
-
-    std::discrete_distribution<size_t> dist(probs_iterator{cur_p, 0}, probs_iterator{cur_p, cur_p->size});
 
-    GGML_UNUSED(probs);
+#ifdef __GNUC__
+    #pragma GCC diagnostic pop
 #endif
 
+    std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
+
     return dist(rng);
 }
 
+/*
 static void llama_log_softmax(float * array, size_t size) {
     float max_l = *std::max_element(array, array + size);
     float sum = 0.f;
@@ -64,6 +59,7 @@ static void llama_log_softmax(float * array, size_t size) {
         array[i] = logf(array[i] / sum);
     }
 }
+*/
 
 static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
     GGML_ASSERT(cur_p->size > 0);
@@ -231,67 +227,92 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
         cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
     }
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    llama_token_data_array cur_p = {
+        /* .data       = */ cur.data(),
+        /* .size       = */ cur.size(),
+        /* .selected   = */ -1,
+        /* .sorted     = */ false,
+    };
 
     llama_sampler_apply(smpl, &cur_p);
 
-    return cur_p.data[cur_p.selected].id;
+    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
+
+    auto token = cur_p.data[cur_p.selected].id;
+
+    llama_sampler_accept(smpl, token);
+
+    return token;
 }
 
 // sampler chain
 
-static struct llama_sampler_i llama_sampler_chain_i = {
-    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
-    /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
-        auto * chain = (llama_sampler_chain *) smpl->ctx;
+static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
+    return "chain";
+}
 
-        time_meas tm(chain->t_sample_us, chain->params.no_perf);
+static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
 
-        for (auto * smpl : chain->samplers) {
-            llama_sampler_accept(smpl, token);
-        }
+    time_meas tm(chain->t_sample_us, chain->params.no_perf);
 
-        chain->n_sample++;
-    },
-    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-        auto * chain = (llama_sampler_chain *) smpl->ctx;
+    for (auto * smpl : chain->samplers) {
+        llama_sampler_accept(smpl, token);
+    }
 
-        time_meas tm(chain->t_sample_us, chain->params.no_perf);
+    chain->n_sample++;
+}
 
-        for (auto * smpl : chain->samplers) {
-            llama_sampler_apply(smpl, cur_p);
-        }
-    },
-    /* .reset  = */ [](struct llama_sampler * smpl) {
-        auto * chain = (llama_sampler_chain *) smpl->ctx;
+static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
 
-        for (auto * smpl : chain->samplers) {
-            llama_sampler_reset(smpl);
-        }
+    time_meas tm(chain->t_sample_us, chain->params.no_perf);
 
-        chain->t_sample_us = 0;
-        chain->n_sample    = 0;
-    },
-    /* .clone  = */ [](const struct llama_sampler * smpl) {
-        const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
+    for (auto * smpl : chain->samplers) {
+        llama_sampler_apply(smpl, cur_p);
+    }
+}
 
-        auto * result = llama_sampler_chain_init(chain_src->params);
+static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
 
-        for (auto * smpl : chain_src->samplers) {
-            llama_sampler_chain_add(result, llama_sampler_clone(smpl));
-        }
+    for (auto * smpl : chain->samplers) {
+        llama_sampler_reset(smpl);
+    }
 
-        return result;
-    },
-    /* .free   = */ [](struct llama_sampler * smpl) {
-        auto * chain = (llama_sampler_chain *) smpl->ctx;
+    chain->t_sample_us = 0;
+    chain->n_sample    = 0;
+}
 
-        for (auto * smpl : chain->samplers) {
-            llama_sampler_free(smpl);
-        }
+static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
+    const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
+
+    auto * result = llama_sampler_chain_init(chain_src->params);
+
+    for (auto * smpl : chain_src->samplers) {
+        llama_sampler_chain_add(result, llama_sampler_clone(smpl));
+    }
+
+    return result;
+}
+
+static void llama_sampler_chain_free(struct llama_sampler * smpl) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
 
-        delete chain;
-    },
+    for (auto * smpl : chain->samplers) {
+        llama_sampler_free(smpl);
+    }
+
+    delete chain;
+}
+
+static struct llama_sampler_i llama_sampler_chain_i = {
+    /* .name   = */ llama_sampler_chain_name,
+    /* .accept = */ llama_sampler_chain_accept,
+    /* .apply  = */ llama_sampler_chain_apply,
+    /* .reset  = */ llama_sampler_chain_reset,
+    /* .clone  = */ llama_sampler_chain_clone,
+    /* .free   = */ llama_sampler_chain_free,
 };
 
 struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
@@ -368,8 +389,6 @@ struct llama_sampler_dist {
     const uint32_t seed;
 
     std::mt19937 rng;
-
-    std::vector<float> probs; // work array
 };
 
 static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
@@ -378,7 +397,7 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
 
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
-    cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+    cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
 }
 
 static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
@@ -419,7 +438,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
         /* .ctx   = */ new llama_sampler_dist {
             /* .seed = */ seed,
             /* .rng  = */ std::mt19937(seed),
-            /* .probs = */ {},
         },
     };
 }
@@ -1023,8 +1041,6 @@ struct llama_sampler_mirostat {
     float mu;
 
     std::mt19937 rng;
-
-    std::vector<float> probs;
 };
 
 static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
@@ -1055,7 +1071,7 @@ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_toke
     llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
     llama_sampler_softmax_impl(cur_p);
 
-    const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+    const int idx = llama_sample_dist(cur_p, ctx->rng);
 
     cur_p->selected = idx;
 
@@ -1111,7 +1127,6 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
             /* .m       = */ m,
             /* .mu      = */ 2.0f*tau,
             /* .rng     = */ std::mt19937(seed),
-            /* .probs   = */ {},
         },
     };
 }
@@ -1127,8 +1142,6 @@ struct llama_sampler_mirostat_v2 {
     float mu;
 
     std::mt19937 rng;
-
-    std::vector<float> probs;
 };
 
 static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
@@ -1152,7 +1165,7 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
     // Normalize the probabilities of the remaining words
     llama_sampler_softmax_impl(cur_p);
 
-    const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+    const int idx = llama_sample_dist(cur_p, ctx->rng);
 
     cur_p->selected = idx;
 
@@ -1207,7 +1220,6 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau,
             /* .eta   = */ eta,
             /* .mu    = */ 2.0f*tau,
             /* .rng   = */ std::mt19937(seed),
-            /* .probs = */ {},
         },
     };
 }
@@ -1527,6 +1539,10 @@ static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /
 static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
 
+    if (ctx->logit_bias.empty()) {
+        return;
+    }
+
     ctx->to_search.clear();
 
     // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
@@ -1538,6 +1554,10 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
         }
     }
 
+    if (ctx->to_search.empty()) {
+        return;
+    }
+
     // search for the remaining candidates that were not found in the previous step
     for (size_t i = 0; i < cur_p->size; ++i) {
         for (const auto & lb : ctx->to_search) {
index 37400c179e9bdd2a6c96e61d0cfdfbe38fa5e2c7..d738b7a4502ed81a6b0393927d5fd13214d3eb02 100644 (file)
@@ -245,7 +245,7 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
         }
     }
 
-    printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f\n",
+    printf("Sampler queue %3s OK with n_vocab=%05zu top_k=%05d top_p=%f min_p=%f\n",
            samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
 }