]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix empty batch causing llama_batch_allocr to crash (#9966)
authorXuan Son Nguyen <redacted>
Tue, 22 Oct 2024 14:59:02 +0000 (16:59 +0200)
committerGitHub <redacted>
Tue, 22 Oct 2024 14:59:02 +0000 (16:59 +0200)
* llama : fix empty batch cause llama_batch_allocr to crash

* move batch_allocr inside decode/encode_internal

* fix build

* add GGML_ASSERT

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
src/llama.cpp

index 7a5a46dce7dbc1ea99c48cd216b49e368311a7f9..24e1f1f01a857dfb9d66eda5aaccb0124600c5c0 100644 (file)
@@ -5177,6 +5177,57 @@ struct llama_model_loader {
     }
 };
 
+// temporary allocate memory for the input batch if needed
+static const llama_seq_id batch_default_seq_id = 0;
+struct llama_batch_allocr {
+    std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id *> seq_id;
+    std::vector<int8_t>         logits;
+    struct llama_batch          batch;
+    // optionally fulfill the batch returned by llama_batch_get_one
+    llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
+        batch = in_batch;
+        GGML_ASSERT(batch.n_tokens > 0);
+        if (!batch.pos) {
+            // determine the last position in KV cache
+            llama_pos last_pos = -1;
+            for (const auto & cell : ctx.kv_self.cells) {
+                if (cell.has_seq_id(batch_default_seq_id)) {
+                    last_pos = std::max(last_pos, cell.pos);
+                }
+            }
+            last_pos++; // next position
+            pos.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                pos[i] = i+last_pos;
+            }
+            batch.pos = pos.data();
+        }
+        if (!batch.n_seq_id) {
+            n_seq_id.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                n_seq_id[i] = seq_id_0.size();
+            }
+            batch.n_seq_id = n_seq_id.data();
+        }
+        if (!batch.seq_id) {
+            seq_id.resize(batch.n_tokens + 1);
+            seq_id[batch.n_tokens] = NULL;
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                seq_id[i] = seq_id_0.data();
+            }
+            batch.seq_id = seq_id.data();
+        }
+        if (!batch.logits) {
+            logits.resize(batch.n_tokens);
+            logits[logits.size() - 1] = true;
+            batch.logits = logits.data();
+        }
+    }
+};
+
 template<>
 bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
     uint32_t tmp;
@@ -17095,16 +17146,20 @@ static void llama_graph_compute(
 //
 static int llama_decode_internal(
          llama_context & lctx,
-           llama_batch   batch) {
+           llama_batch   inp_batch) {
 
     lctx.is_encoding = false;
-    const uint32_t n_tokens_all = batch.n_tokens;
 
-    if (n_tokens_all == 0) {
+    if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
+    // temporary allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(lctx, inp_batch);
+    const llama_batch & batch = batch_allocr.batch;
+    const uint32_t n_tokens_all = batch.n_tokens;
+
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
@@ -17409,17 +17464,20 @@ static int llama_decode_internal(
 //
 static int llama_encode_internal(
          llama_context & lctx,
-           llama_batch   batch) {
+           llama_batch   inp_batch) {
 
     lctx.is_encoding = true;
 
-    const uint32_t n_tokens = batch.n_tokens;
-
-    if (n_tokens == 0) {
+    if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
+    // temporary allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(lctx, inp_batch);
+    const llama_batch & batch = batch_allocr.batch;
+    const uint32_t n_tokens = batch.n_tokens;
+
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
@@ -21090,61 +21148,10 @@ void llama_batch_free(struct llama_batch batch) {
     if (batch.logits)   free(batch.logits);
 }
 
-// temporary allocate memory for the input batch if needed
-static const llama_seq_id batch_default_seq_id = 0;
-struct llama_batch_allocr {
-    std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
-    std::vector<llama_pos>      pos;
-    std::vector<int32_t>        n_seq_id;
-    std::vector<llama_seq_id *> seq_id;
-    std::vector<int8_t>         logits;
-    struct llama_batch          batch;
-    // optionally fulfill the batch returned by llama_batch_get_one
-    llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
-        batch = in_batch;
-        if (!batch.pos) {
-            // determine the last position in KV cache
-            llama_pos last_pos = -1;
-            for (const auto & cell : ctx->kv_self.cells) {
-                if (cell.has_seq_id(batch_default_seq_id)) {
-                    last_pos = std::max(last_pos, cell.pos);
-                }
-            }
-            last_pos++; // next position
-            pos.resize(batch.n_tokens);
-            for (int32_t i = 0; i < batch.n_tokens; i++) {
-                pos[i] = i+last_pos;
-            }
-            batch.pos = pos.data();
-        }
-        if (!batch.n_seq_id) {
-            n_seq_id.resize(batch.n_tokens);
-            for (int32_t i = 0; i < batch.n_tokens; i++) {
-                n_seq_id[i] = seq_id_0.size();
-            }
-            batch.n_seq_id = n_seq_id.data();
-        }
-        if (!batch.seq_id) {
-            seq_id.resize(batch.n_tokens + 1);
-            seq_id[batch.n_tokens] = NULL;
-            for (int32_t i = 0; i < batch.n_tokens; i++) {
-                seq_id[i] = seq_id_0.data();
-            }
-            batch.seq_id = seq_id.data();
-        }
-        if (!batch.logits) {
-            logits.resize(batch.n_tokens);
-            logits[logits.size() - 1] = true;
-            batch.logits = logits.data();
-        }
-    }
-};
-
 int32_t llama_encode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    llama_batch_allocr batch_allocr(ctx, batch);
-    const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
+    const int ret = llama_encode_internal(*ctx, batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }
@@ -21155,8 +21162,7 @@ int32_t llama_encode(
 int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    llama_batch_allocr batch_allocr(ctx, batch);
-    const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
+    const int ret = llama_decode_internal(*ctx, batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }