]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
bench : handle decode errors (#13548)
authorGeorgi Gerganov <redacted>
Thu, 15 May 2025 02:57:02 +0000 (05:57 +0300)
committerGitHub <redacted>
Thu, 15 May 2025 02:57:02 +0000 (05:57 +0300)
ggml-ci

tools/llama-bench/llama-bench.cpp

index 9457e6815e231073d60d9222a5b9edb411bd3a3d..53dbdda2a35f399c0f33ccc8eb1d3bde8a7c7627 100644 (file)
@@ -1736,7 +1736,7 @@ struct sql_printer : public printer {
     }
 };
 
-static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
+static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
     const llama_model * model   = llama_get_model(ctx);
@@ -1753,14 +1753,19 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
         for (int i = 1; i < n_tokens; i++) {
             tokens[i] = std::rand() % n_vocab;
         }
-        llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
+        int res = llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
+        if (res != 0) {
+            fprintf(stderr, "%s: failed to decode prompt batch, res = %d\n", __func__, res);
+            return false;
+        }
         n_processed += n_tokens;
     }
 
     llama_synchronize(ctx);
+    return true;
 }
 
-static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
+static bool test_gen(llama_context * ctx, int n_gen, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
     const llama_model * model   = llama_get_model(ctx);
@@ -1770,10 +1775,15 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
     llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
 
     for (int i = 0; i < n_gen; i++) {
-        llama_decode(ctx, llama_batch_get_one(&token, 1));
+        int res = llama_decode(ctx, llama_batch_get_one(&token, 1));
+        if (res != 0) {
+            fprintf(stderr, "%s: failed to decode generation batch, res = %d\n", __func__, res);
+            return false;
+        }
         llama_synchronize(ctx);
         token = std::rand() % n_vocab;
     }
+    return true;
 }
 
 static void llama_null_log_callback(enum ggml_log_level level, const char * text, void * user_data) {
@@ -1917,13 +1927,21 @@ int main(int argc, char ** argv) {
                 fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
             }
             //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
-            test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+            bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+            if (!res) {
+                fprintf(stderr, "%s: error: failed to run prompt warmup\n", __func__);
+                exit(1);
+            }
         }
         if (t.n_gen > 0) {
             if (params.progress) {
                 fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count);
             }
-            test_gen(ctx, 1, t.n_threads);
+            bool res = test_gen(ctx, 1, t.n_threads);
+            if (!res) {
+                fprintf(stderr, "%s: error: failed to run gen warmup\n", __func__);
+                exit(1);
+            }
         }
 
         for (int i = 0; i < params.reps; i++) {
@@ -1934,7 +1952,11 @@ int main(int argc, char ** argv) {
                     fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
                             i + 1, params.reps);
                 }
-                test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
+                bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
+                if (!res) {
+                    fprintf(stderr, "%s: error: failed to run depth\n", __func__);
+                    exit(1);
+                }
             }
 
             uint64_t t_start = get_time_ns();
@@ -1944,14 +1966,22 @@ int main(int argc, char ** argv) {
                     fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count,
                             i + 1, params.reps);
                 }
-                test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+                bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+                if (!res) {
+                    fprintf(stderr, "%s: error: failed to run prompt\n", __func__);
+                    exit(1);
+                }
             }
             if (t.n_gen > 0) {
                 if (params.progress) {
                     fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
                             i + 1, params.reps);
                 }
-                test_gen(ctx, t.n_gen, t.n_threads);
+                bool res = test_gen(ctx, t.n_gen, t.n_threads);
+                if (!res) {
+                    fprintf(stderr, "%s: error: failed to run gen\n", __func__);
+                    exit(1);
+                }
             }
 
             uint64_t t_ns = get_time_ns() - t_start;