]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : validate seq id batch input (#13809)
authorGeorgi Gerganov <redacted>
Tue, 27 May 2025 06:40:59 +0000 (09:40 +0300)
committerGitHub <redacted>
Tue, 27 May 2025 06:40:59 +0000 (09:40 +0300)
* llama : validate seq id batch input

ggml-ci

* cont : fix the fix

ggml-ci

src/llama-context.cpp

index ad77cae20eb504fc9bea8d21a1abee597587e27e..e153351af38093a0b788dfc576f80d8301db1946 100644 (file)
@@ -693,12 +693,18 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
+    // TODO: move the validation to the llama_batch_allocr
     if (batch.token) {
         for (int32_t i = 0; i < n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
+
+            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+                throw -1;
+            }
         }
     }
 
@@ -887,11 +893,17 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
+    // TODO: move the validation to the llama_batch_allocr
     if (batch.token) {
         for (int64_t i = 0; i < n_tokens_all; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
-                throw std::runtime_error("invalid token");
+                return -1;
+            }
+
+            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+                return -1;
             }
         }
     }