]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : fix segfault on long system prompt (#8987)
authorcompilade <redacted>
Wed, 14 Aug 2024 06:51:02 +0000 (02:51 -0400)
committerGitHub <redacted>
Wed, 14 Aug 2024 06:51:02 +0000 (09:51 +0300)
* server : fix segfault on long system prompt

* server : fix parallel generation with very small batch sizes

* server : fix typo in comment

examples/server/server.cpp

index 1621c7c43961c1f4c7ec7b229065c0e8f5e04b61..c25338f57376702624ef7b1f72c88b1c257a086b 100644 (file)
@@ -754,13 +754,13 @@ struct server_context {
         default_generation_settings_for_props = get_formated_generation(slots.front());
         default_generation_settings_for_props["seed"] = -1;
 
-        // the update_slots() logic will always submit a maximum of n_batch tokens
+        // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
         // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
         {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch = llama_batch_init(n_batch, 0, 1);
+            batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
         }
 
         metrics.init();
@@ -1137,28 +1137,19 @@ struct server_context {
         if (!system_prompt.empty()) {
             system_tokens = ::llama_tokenize(ctx, system_prompt, true);
 
-            llama_batch_clear(batch);
+            const int32_t n_batch = llama_n_batch(ctx);
+            const int32_t n_tokens_prompt = system_tokens.size();
 
-            for (int i = 0; i < (int)system_tokens.size(); ++i) {
-                llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
-            }
+            for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
+                const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);
 
-            const int32_t n_batch = llama_n_batch(ctx);
+                llama_batch_clear(batch);
 
-            for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-                const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
-                llama_batch batch_view = {
-                    n_tokens,
-                    batch.token    + i,
-                    nullptr,
-                    batch.pos      + i,
-                    batch.n_seq_id + i,
-                    batch.seq_id   + i,
-                    batch.logits   + i,
-                    0, 0, 0, // unused
-                };
+                for (int32_t j = 0; j < n_tokens; ++j) {
+                    llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
+                }
 
-                if (llama_decode(ctx, batch_view) != 0) {
+                if (llama_decode(ctx, batch) != 0) {
                     LOG_ERROR("llama_decode() failed", {});
                     return;
                 }