]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
batched-bench : fix llama_synchronize usage during prompt processing (#15835)
authorGeorgi Gerganov <redacted>
Mon, 8 Sep 2025 07:27:07 +0000 (10:27 +0300)
committerGitHub <redacted>
Mon, 8 Sep 2025 07:27:07 +0000 (10:27 +0300)
ggml-ci

tools/batched-bench/batched-bench.cpp

index 46dd12caae544311631b106267481f5f9224c4bf..fcfcd80771c516d0ccc59a98b7545ed77fd5acc2 100644 (file)
@@ -71,7 +71,7 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
 
     // decode in batches of ctx_params.n_batch tokens
-    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
+    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize) {
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
 
@@ -91,7 +91,9 @@ int main(int argc, char ** argv) {
                 return false;
             }
 
-            llama_synchronize(ctx);
+            if (synchronize) {
+                llama_synchronize(ctx);
+            }
         }
 
         return true;
@@ -103,7 +105,7 @@ int main(int argc, char ** argv) {
             common_batch_add(batch, get_token_rand(), i, { 0 }, false);
         }
 
-        if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+        if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return 1;
         }
@@ -138,15 +140,17 @@ int main(int argc, char ** argv) {
                     }
                 }
 
-                const auto t_pp_start = ggml_time_us();
-
                 llama_memory_clear(mem, false);
 
-                if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                const auto t_pp_start = ggml_time_us();
+
+                if (!decode_helper(ctx, batch, ctx_params.n_batch, false)) {
                     LOG_ERR("%s: llama_decode() failed\n", __func__);
                     return 1;
                 }
 
+                llama_synchronize(ctx);
+
                 const auto t_pp_end = ggml_time_us();
 
                 if (is_pp_shared) {
@@ -158,7 +162,7 @@ int main(int argc, char ** argv) {
                         // run one dummy token to apply the memory copy
                         common_batch_clear(batch);
                         common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
-                        if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                        if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                             LOG_ERR("%s: llama_decode() failed\n", __func__);
                             return 1;
                         }
@@ -175,7 +179,7 @@ int main(int argc, char ** argv) {
                         common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
                     }
 
-                    if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                    if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
                         LOG_ERR("%s: llama_decode() failed\n", __func__);
                         return 1;
                     }