git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : samplers accept the prompt correctly (#10019)
author wwoodsTM <redacted>
Wed, 23 Oct 2024 19:27:51 +0000 (13:27 -0600)
committer GitHub <redacted>
Wed, 23 Oct 2024 19:27:51 +0000 (22:27 +0300)
examples/server/server.cpp

index 3992108e7f38311696491ed7c0c978df08a33ace..51f30ffeab9808263e9c00c13651c309d1482770 100644 (file)
@@ -2163,17 +2163,10 @@ struct server_context {
                                 GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                             }
 
-                            common_sampler_reset(slot.smpl);
-
                             if (slot.params.cache_prompt) {
                                 // reuse any previously computed tokens that are common with the new prompt
                                 slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
 
-                                // push the prompt into the sampling context (do not apply grammar)
-                                for (int i = 0; i < slot.n_past; ++i) {
-                                    common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
-                                }
-
                                 // reuse chunks from the cached prompt by shifting their KV cache in the new position
                                 if (params.n_cache_reuse > 0) {
                                     size_t head_c = slot.n_past; // cache
@@ -2206,8 +2199,6 @@ struct server_context {
                                             for (size_t i = 0; i < n_match; i++) {
                                                 slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
 
-                                                common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
-
                                                 slot.n_past++;
                                             }
 
@@ -2259,8 +2250,6 @@ struct server_context {
 
                         // there is no common part left
                         slot.n_past = 0;
-
-                        common_sampler_reset(slot.smpl);
                     }
 
                     SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
@@ -2288,6 +2277,13 @@ struct server_context {
 
                         GGML_ASSERT(batch.n_tokens > 0);
 
+                        common_sampler_reset(slot.smpl);
+
+                        // Process all prompt tokens through sampler system
+                        for (int i = 0; i < slot.n_prompt_tokens; ++i) {
+                            common_sampler_accept(slot.smpl, prompt_tokens[i], false);
+                        }
+
                         // extract the logits only for the last token
                         batch.logits[batch.n_tokens - 1] = true;