]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
grammar : check the full vocab only if necessary (opt) (#4306)
authorkalomaze <redacted>
Sat, 23 Dec 2023 09:27:07 +0000 (03:27 -0600)
committerGitHub <redacted>
Sat, 23 Dec 2023 09:27:07 +0000 (11:27 +0200)
* Check the full vocab for grammar only if necessary

* Fix missing logit restoration step (?)

Does this matter, actually?

* Fix whitespace / formatting

* Adjust comment

* Didn't mean to push test gbnf

* Split sampling into the helper function (?)

And also revert the changes made to the header

* common : fix final newline

---------

Co-authored-by: Georgi Gerganov <redacted>
common/sampling.cpp

index f4e76df31bee3263d87b187e0528531285c39029..5b15204be88c48c6c9527782dac5201d8f452cd7 100644 (file)
@@ -149,11 +149,12 @@ static void sampler_queue(
     }
 }
 
-llama_token llama_sampling_sample(
+static llama_token llama_sampling_sample_impl(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
+                  const int idx,
+                  bool is_resampling) {  // Add a parameter to indicate if we are resampling
     const llama_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -173,8 +174,17 @@ llama_token llama_sampling_sample(
 
     llama_token id = 0;
 
+    // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
+    // Declare original_logits at the beginning of the function scope
+    std::vector<float> original_logits;
+
+    if (!is_resampling) {
+        // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
+        original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
+    }
+
     // apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
         logits[it->first] += it->second;
@@ -210,7 +220,8 @@ llama_token llama_sampling_sample(
         }
     }
 
-    if (ctx_sampling->grammar != NULL) {
+    // If we are in the resampling phase, apply grammar checks before sampling logic
+    if (is_resampling && ctx_sampling->grammar != NULL) {
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
@@ -252,9 +263,40 @@ llama_token llama_sampling_sample(
         }
     }
 
+    if (ctx_sampling->grammar != NULL && !is_resampling) {
+        // Create an array with a single token data element for the sampled id
+        llama_token_data single_token_data = {id, logits[id], 0.0f};
+        llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
+
+        // Apply grammar constraints to the single token
+        llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
+
+        // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
+        bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+
+        // If the token is not valid according to the grammar, perform resampling
+        if (!is_valid) {
+            LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
+
+            // Restore logits from the copy
+            std::copy(original_logits.begin(), original_logits.end(), logits);
+
+            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true);  // Pass true for is_resampling
+        }
+    }
+
     return id;
 }
 
+llama_token llama_sampling_sample(
+                  struct llama_sampling_context * ctx_sampling,
+                  struct llama_context * ctx_main,
+                  struct llama_context * ctx_cfg,
+                  const int idx) {
+    // Call the implementation function with is_resampling set to false by default
+    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
+}
+
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,