]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama-bench : use random tokens to improve accuracy with mixtral (#6069)
authorslaren <redacted>
Fri, 15 Mar 2024 08:22:24 +0000 (09:22 +0100)
committerGitHub <redacted>
Fri, 15 Mar 2024 08:22:24 +0000 (10:22 +0200)
examples/llama-bench/llama-bench.cpp

index d6e5e0497dc3ae5443af4e1ccf1cb952071fad68..32eea786919f63408330c021ad45b135554425f3 100644 (file)
@@ -8,6 +8,7 @@
 #include <cstdio>
 #include <cstring>
 #include <ctime>
+#include <cstdlib>
 #include <iterator>
 #include <map>
 #include <numeric>
@@ -1123,15 +1124,19 @@ struct sql_printer : public printer {
 static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
-    //std::vector<llama_token> tokens(n_prompt, llama_token_bos(llama_get_model(ctx)));
-    //llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt, n_past, 0));
-    //GGML_UNUSED(n_batch);
+    const llama_model * model = llama_get_model(ctx);
+    const int32_t n_vocab = llama_n_vocab(model);
+
+    std::vector<llama_token> tokens(n_batch);
 
-    std::vector<llama_token> tokens(n_batch, llama_token_bos(llama_get_model(ctx)));
     int n_processed = 0;
 
     while (n_processed < n_prompt) {
         int n_tokens = std::min(n_prompt - n_processed, n_batch);
+        tokens[0] = n_processed == 0 && llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;
+        for (int i = 1; i < n_tokens; i++) {
+            tokens[i] = std::rand() % n_vocab;
+        }
         llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
         n_processed += n_tokens;
     }
@@ -1142,11 +1147,15 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
 static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
-    llama_token token = llama_token_bos(llama_get_model(ctx));
+    const llama_model * model = llama_get_model(ctx);
+    const int32_t n_vocab = llama_n_vocab(model);
+
+    llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;
 
     for (int i = 0; i < n_gen; i++) {
         llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
         llama_synchronize(ctx);
+        token = std::rand() % n_vocab;
     }
 }