]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : sanity checks for access to logits (#4274)
authorJared Van Bortel <redacted>
Sat, 16 Dec 2023 03:16:15 +0000 (22:16 -0500)
committerGitHub <redacted>
Sat, 16 Dec 2023 03:16:15 +0000 (22:16 -0500)
Co-authored-by: Georgi Gerganov <redacted>
llama.cpp

index eddb7085992d763945bb6db3e2ed4f5e98af1c33..58fe7492e2bf14c92ed8de90f557d2074187d996 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1505,6 +1505,10 @@ struct llama_context {
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
+#ifndef NDEBUG
+    // guard against access to unset logits
+    std::vector<bool>  logits_valid;
+#endif
     bool logits_all = false;
 
     // input embedding (1-dimensional array: [n_embd])
@@ -6150,6 +6154,14 @@ static int llama_decode_internal(
     {
         auto & logits_out = lctx.logits;
 
+#ifndef NDEBUG
+        auto & logits_valid = lctx.logits_valid;
+        logits_valid.clear();
+        logits_valid.resize(n_tokens);
+
+        logits_out.clear();
+#endif
+
         if (batch.logits) {
             logits_out.resize(n_vocab * n_tokens);
             for (uint32_t i = 0; i < n_tokens; i++) {
@@ -6157,13 +6169,22 @@ static int llama_decode_internal(
                     continue;
                 }
                 memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
+#ifndef NDEBUG
+                logits_valid[i] = true;
+#endif
             }
         } else if (lctx.logits_all) {
             logits_out.resize(n_vocab * n_tokens);
             memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
+#ifndef NDEBUG
+            std::fill(logits_valid.begin(), logits_valid.end(), true);
+#endif
         } else {
             logits_out.resize(n_vocab);
             memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab);
+#ifndef NDEBUG
+            logits_valid[n_tokens - 1] = true;
+#endif
         }
     }
 
@@ -10052,6 +10073,7 @@ float * llama_get_logits(struct llama_context * ctx) {
 }
 
 float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
+    assert(ctx->logits_valid.at(i));
     return ctx->logits.data() + i*ctx->model.hparams.n_vocab;
 }