git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : fix parallel speculative decoding (#10513)
author Georgi Gerganov <redacted>
Tue, 26 Nov 2024 11:36:40 +0000 (13:36 +0200)
committer GitHub <redacted>
Tue, 26 Nov 2024 11:36:40 +0000 (13:36 +0200)
ggml-ci

examples/server/server.cpp

index c0ea4faf77d42d6925a4d3ea9db2a848bb3ff594..9c86407c28ebacfc8183e6e6a9f0bbbdf90ed977 100644 (file)
@@ -2267,50 +2267,49 @@ struct server_context {
                     continue; // continue loop of slots
                 }
 
-                llama_token id;
+                llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
 
-                {
-                    completion_token_output result;
-
-                    id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+                slot.i_batch = -1;
 
-                    slot.i_batch = -1;
+                common_sampler_accept(slot.smpl, id, true);
 
-                    common_sampler_accept(slot.smpl, id, true);
-
-                    slot.n_decoded += 1;
-                    if (slot.n_decoded == 1) {
-                        slot.t_start_generation = ggml_time_us();
-                        slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
-                        metrics.on_prompt_eval(slot);
-                    }
+                slot.n_decoded += 1;
+                if (slot.n_decoded == 1) {
+                    slot.t_start_generation = ggml_time_us();
+                    slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+                    metrics.on_prompt_eval(slot);
+                }
 
-                    result.tok = id;
+                completion_token_output result;
+                result.tok = id;
 
-                    const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+                const auto * cur_p = common_sampler_get_candidates(slot.smpl);
 
-                    for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
-                        result.probs.push_back({
-                            cur_p->data[i].id,
-                                i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                        });
-                    }
+                for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
+                    result.probs.push_back({
+                        cur_p->data[i].id,
+                            i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+                    });
+                }
 
-                    if (!process_token(result, slot)) {
-                        // release slot because of stop condition
-                        slot.release();
-                        slot.print_timings();
-                        send_final_response(slot);
-                        metrics.on_prediction(slot);
-                        continue;
-                    }
+                if (!process_token(result, slot)) {
+                    // release slot because of stop condition
+                    slot.release();
+                    slot.print_timings();
+                    send_final_response(slot);
+                    metrics.on_prediction(slot);
+                    continue;
                 }
+            }
 
-                // check if the slot supports speculative decoding
-                if (!slot.can_speculate()) {
+            // do speculative decoding
+            for (auto & slot : slots) {
+                if (!slot.is_processing() || !slot.can_speculate()) {
                     continue;
                 }
 
+                llama_token id = slot.sampled;
+
                 struct common_speculative_params params_spec;
                 params_spec.n_draft   = slot.params.speculative.n_max;
                 params_spec.n_reuse   = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
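Note on the shape of the change: the hunk above moves speculative decoding out of the per-batch sampling loop and into a second pass over all slots, so that every active slot first gets its regularly sampled token (and stop-condition / metrics handling) before any slot starts drafting. The sketch below is only a minimal illustration of that two-pass control flow, not the actual implementation: Slot is a simplified stand-in for the real server_slot, and the common_sampler_* / common_speculative_* calls from the diff are referenced only in comments.

// Hedged sketch of the post-commit control flow in update_slots():
// pass 1 samples for every active slot, pass 2 runs speculative decoding.
#include <vector>

struct Slot {
    bool processing      = false;  // slot currently serves a request
    bool has_draft_model = false;  // slot was configured with a draft context (ctx_dft)
    int  sampled         = -1;     // last token sampled for this slot

    bool is_processing() const { return processing; }
    bool can_speculate() const { return has_draft_model; }
};

static void update_slots(std::vector<Slot> & slots) {
    // Pass 1: sample one token per active slot from the shared decode batch,
    // accept it into the sampler, and handle stop conditions / timing metrics.
    for (auto & slot : slots) {
        if (!slot.is_processing()) {
            continue;
        }
        // ... common_sampler_sample(), common_sampler_accept(), process_token() ...
        slot.sampled = /* token id returned by the sampler */ 0;
    }

    // Pass 2: only after every active slot has its sampled token, run
    // speculative decoding per slot, so slots running in parallel no longer
    // interleave their draft/verify steps with another slot's sampling.
    for (auto & slot : slots) {
        if (!slot.is_processing() || !slot.can_speculate()) {
            continue;
        }
        const int id = slot.sampled;  // seed the draft with the last accepted token
        // ... build common_speculative_params, common_speculative_gen_draft(),
        //     verify the drafted tokens against the target model ...
        (void) id;
    }
}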