]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sampling : avoid expensive softmax during greedy sampling (#9605)
authorGeorgi Gerganov <redacted>
Tue, 24 Sep 2024 06:03:17 +0000 (09:03 +0300)
committerGitHub <redacted>
Tue, 24 Sep 2024 06:03:17 +0000 (09:03 +0300)
* sampling : avoid expensive softmax during greedy sampling

ggml-ci

* speculative : fix default RNG seed + set sparams.n_probs

* Update tests/test-sampling.cpp

Co-authored-by: slaren <redacted>
* sampling : add clarifying comment [no ci]

---------

Co-authored-by: slaren <redacted>
common/sampling.cpp
examples/speculative/speculative.cpp
include/llama.h
src/llama-sampling.cpp
tests/test-sampling.cpp

index e51d07611d42c0b8067df53f15e65451e0c59252..3dc7f112094e61c93853fd04afd9e3c0a05f9e12 100644 (file)
@@ -209,7 +209,15 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
             GGML_ASSERT(false && "unknown mirostat version");
         }
     } else {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+        if (params.n_probs > 0) {
+            // some use cases require to sample greedily, but still obtain the probabilities of the top tokens
+            // ref: https://github.com/ggerganov/llama.cpp/pull/9605
+            //
+            // the following will not produce exactly the same probs as applyging softmax to the full vocabulary, but
+            // it is much faster, since we avoid sorting all tokens and should give a good approximation
+            llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+        }
         llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
     }
 
index fbac21811638bdbf5f8861fdef7a975eda8726b0..adf6255e1449f293a17387abcd33a541a7738fbb 100644 (file)
@@ -32,6 +32,9 @@ struct seq_draft {
 int main(int argc, char ** argv) {
     gpt_params params;
 
+    // needed to get candidate probs even for temp <= 0.0
+    params.sparams.n_probs = 128;
+
     if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
         return 1;
     }
@@ -49,7 +52,7 @@ int main(int argc, char ** argv) {
     // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
     const float p_split  = params.p_split;
 
-    std::default_random_engine rng(params.sparams.seed);
+    std::default_random_engine rng(params.sparams.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sparams.seed);
     std::uniform_real_distribution<> u_dist;
 
     // init llama.cpp
index f316a87ba31509e9738ff63d206cc2b6ebd41fb2..132937a0700e7c9bd69f517f05a7141a04a25a55 100644 (file)
@@ -1066,6 +1066,7 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_init_dist       (uint32_t seed);
 
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
+    /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
     LLAMA_API struct llama_sampler * llama_sampler_init_softmax    (void);
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
index 5299f51160dac41a65f439eff28f27639ab3aa0b..e255a8fc4fd548e135bb00465698b1b61a5043f2 100644 (file)
@@ -3,13 +3,14 @@
 #include "llama-vocab.h"
 #include "llama-grammar.h"
 
-#include <cassert>
 #include <algorithm>
-#include <cstring>
-#include <ctime>
+#include <cassert>
 #include <cfloat>
 #include <chrono>
 #include <cmath>
+#include <cstdlib>
+#include <cstring>
+#include <ctime>
 #include <numeric>
 #include <random>
 #include <unordered_map>
index d738b7a4502ed81a6b0393927d5fd13214d3eb02..6e021c4c70357d123783510b318100c91e901e90 100644 (file)
@@ -1,6 +1,5 @@
 #include "ggml.h"
 #include "llama.h"
-#include "llama-sampling.h"
 
 #ifdef NDEBUG
 #undef NDEBUG
@@ -249,6 +248,45 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
            samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
 }
 
+static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) {
+    std::vector<llama_token_data> cur(data.size());
+    std::copy(data.begin(), data.end(), cur.begin());
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    llama_sampler_apply(cnstr, &cur_p);
+    llama_sampler_reset(cnstr);
+    const int64_t t_start = ggml_time_us();
+    for (int i = 0; i < n_iter; i++) {
+        std::copy(data.begin(), data.end(), cur.begin());
+        llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+        llama_sampler_apply(cnstr, &cur_p);
+        llama_sampler_reset(cnstr);
+    }
+    const int64_t t_end = ggml_time_us();
+    llama_sampler_free(cnstr);
+    printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
+}
+
+#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
+
+static void test_perf() {
+    const int n_vocab = 1 << 17;
+
+    std::vector<llama_token_data> data;
+
+    data.reserve(n_vocab);
+    for (int i = 0; i < n_vocab; i++) {
+        const float logit = 2.0f*((float)(rand())/RAND_MAX - 0.5f);
+        data.emplace_back(llama_token_data{i, logit, 0.0f});
+    }
+
+    BENCH(llama_sampler_init_top_k    (40),      data, 32);
+    BENCH(llama_sampler_init_top_p    (0.8f, 1), data, 32);
+    BENCH(llama_sampler_init_min_p    (0.2f, 1), data, 32);
+    BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
+    BENCH(llama_sampler_init_typical  (0.5f, 1), data, 32);
+    BENCH(llama_sampler_init_softmax  (),        data, 32);
+}
+
 int main(void) {
     ggml_time_init();
 
@@ -316,5 +354,7 @@ int main(void) {
 
     printf("OK\n");
 
+    test_perf();
+
     return 0;
 }