]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : reuse llama_sample_token common util (#3494)
authorJhen-Jie Hong <redacted>
Fri, 6 Oct 2023 12:44:24 +0000 (07:44 -0500)
committerGitHub <redacted>
Fri, 6 Oct 2023 12:44:24 +0000 (15:44 +0300)
* server : reuse llama_sample_token common function

* common : use n_probs for temperature sampling

common/common.cpp
examples/server/server.cpp

index 6b9b4695ca28882d5f5e5b955e723335b2b80e24..186f5b26807d850671fcc8dc6a4f7d90be71ef06 100644 (file)
@@ -1020,10 +1020,11 @@ llama_token llama_sample_token(
             id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
         } else {
             // Temperature sampling
-            llama_sample_top_k      (ctx, &cur_p, top_k, 1);
-            llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
-            llama_sample_typical    (ctx, &cur_p, typical_p, 1);
-            llama_sample_top_p      (ctx, &cur_p, top_p, 1);
+            size_t min_keep = std::max(1, params.n_probs);
+            llama_sample_top_k      (ctx, &cur_p, top_k, min_keep);
+            llama_sample_tail_free  (ctx, &cur_p, tfs_z, min_keep);
+            llama_sample_typical    (ctx, &cur_p, typical_p, min_keep);
+            llama_sample_top_p      (ctx, &cur_p, top_p, min_keep);
             llama_sample_temp(ctx, &cur_p, temp);
 
             {
index 5f9cdecd548e2f571b0e6b54690e98890f7dd240..c53a64867336f9433417d215871f194fe2df4bd4 100644 (file)
@@ -534,98 +534,20 @@ struct llama_server_context
             return result;
         }
 
-        // out of user input, sample next token
-        const float temp = params.temp;
-        const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(model) : params.top_k;
-        const float top_p = params.top_p;
-        const float tfs_z = params.tfs_z;
-        const float typical_p = params.typical_p;
-        const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
-        const float repeat_penalty = params.repeat_penalty;
-        const float alpha_presence = params.presence_penalty;
-        const float alpha_frequency = params.frequency_penalty;
-        const int mirostat = params.mirostat;
-        const float mirostat_tau = params.mirostat_tau;
-        const float mirostat_eta = params.mirostat_eta;
-        const bool penalize_nl = params.penalize_nl;
-        const int32_t n_probs = params.n_probs;
-
         {
-            auto *logits = llama_get_logits(ctx);
-            auto n_vocab = llama_n_vocab(model);
-
-            // Apply params.logit_bias map
-            for (const auto &it : params.logit_bias)
-            {
-                logits[it.first] += it.second;
-            }
-
+            // out of user input, sample next token
             std::vector<llama_token_data> candidates;
-            candidates.reserve(n_vocab);
-            for (llama_token token_id = 0; token_id < n_vocab; token_id++)
-            {
-                candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-            }
+            candidates.reserve(llama_n_vocab(model));
 
-            llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
-
-            // Apply penalties
-            float nl_logit = logits[llama_token_nl(ctx)];
-            auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
-            llama_sample_repetition_penalty(ctx, &candidates_p,
-                                            last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                                            last_n_repeat, repeat_penalty);
-            llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
-                                                          last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                                                          last_n_repeat, alpha_frequency, alpha_presence);
-            if (!penalize_nl)
-            {
-                logits[llama_token_nl(ctx)] = nl_logit;
-            }
+            result.tok = llama_sample_token(ctx, NULL, grammar, params, last_n_tokens, candidates);
 
-            if (grammar != nullptr) {
-                llama_sample_grammar(ctx, &candidates_p, grammar);
-            }
+            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
-            if (temp <= 0)
-            {
-                // Greedy sampling
-                result.tok = llama_sample_token_greedy(ctx, &candidates_p);
-                if (n_probs > 0)
-                {
-                    llama_sample_softmax(ctx, &candidates_p);
-                }
-            }
-            else
+            const int32_t n_probs = params.n_probs;
+            if (params.temp <= 0 && n_probs > 0)
             {
-                if (mirostat == 1)
-                {
-                    static float mirostat_mu = 2.0f * mirostat_tau;
-                    const int mirostat_m = 100;
-                    llama_sample_temp(ctx, &candidates_p, temp);
-                    result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
-                }
-                else if (mirostat == 2)
-                {
-                    static float mirostat_mu = 2.0f * mirostat_tau;
-                    llama_sample_temp(ctx, &candidates_p, temp);
-                    result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
-                }
-                else
-                {
-                    // Temperature sampling
-                    size_t min_keep = std::max(1, n_probs);
-                    llama_sample_top_k(ctx, &candidates_p, top_k, min_keep);
-                    llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep);
-                    llama_sample_typical(ctx, &candidates_p, typical_p, min_keep);
-                    llama_sample_top_p(ctx, &candidates_p, top_p, min_keep);
-                    llama_sample_temp(ctx, &candidates_p, temp);
-                    result.tok = llama_sample_token(ctx, &candidates_p);
-                }
-            }
-
-            if (grammar != nullptr) {
-                llama_grammar_accept_token(ctx, grammar, result.tok);
+                // For llama_sample_token_greedy we need to sort candidates
+                llama_sample_softmax(ctx, &candidates_p);
             }
 
             for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)