]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add llama_beam_search() (#2267)
authorMatt Pulver <redacted>
Fri, 25 Aug 2023 15:18:48 +0000 (11:18 -0400)
committerGitHub <redacted>
Fri, 25 Aug 2023 15:18:48 +0000 (18:18 +0300)
* Add llama_beam_search().

* Add '// Beam search' heading to llama.{h,cpp} after llama_grammar_accept_token().

* Add space around * pointers and & references.

* Add spaces around comparison and assignment operators.

* Prefer west const.

* Use llama_ prefix for structs in global namespace.

* Delete obsolete comment from an earlier revision.

* Change eos to eob in llama_beam and llama_beam_view structs.

common/common.h
examples/CMakeLists.txt
examples/beam_search/CMakeLists.txt [new file with mode: 0644]
examples/beam_search/beam_search.cpp [new file with mode: 0644]
examples/server/server.cpp
llama.cpp
llama.h

index 17d271e6750e27210151b06457fb3eb3eb500187..ce61265f8c12472427ac83be17653f9cb980d189 100644 (file)
@@ -28,6 +28,7 @@ struct gpt_params {
     int32_t main_gpu                        = 0;    // the GPU that is used for scratch and small tensors
     float   tensor_split[LLAMA_MAX_DEVICES] = {0};  // how split tensors should be distributed across GPUs
     int32_t n_probs                         = 0;    // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t n_beams                         = 0;    // if non-zero then use beam search of given width.
     float   rope_freq_base                  = 10000.0f; // RoPE base frequency
     float   rope_freq_scale                 = 1.0f;     // RoPE frequency scaling factor
 
index d2176c910c299120284819d0a88b0d6b59e9591d..94b78522487484cd15494336fcddd9fca544a8bd 100644 (file)
@@ -25,6 +25,7 @@ else()
     add_subdirectory(simple)
     add_subdirectory(embd-input)
     add_subdirectory(llama-bench)
+    add_subdirectory(beam_search)
     if (LLAMA_METAL)
         add_subdirectory(metal)
     endif()
diff --git a/examples/beam_search/CMakeLists.txt b/examples/beam_search/CMakeLists.txt
new file mode 100644 (file)
index 0000000..b29e010
--- /dev/null
@@ -0,0 +1,8 @@
+set(TARGET beam_search)
+add_executable(${TARGET} beam_search.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+if(TARGET BUILD_INFO)
+  add_dependencies(${TARGET} BUILD_INFO)
+endif()
diff --git a/examples/beam_search/beam_search.cpp b/examples/beam_search/beam_search.cpp
new file mode 100644 (file)
index 0000000..1c04fab
--- /dev/null
@@ -0,0 +1,188 @@
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
+#include "common.h"
+#include "llama.h"
+#include "build-info.h"
+
+#include <cassert>
+#include <cinttypes>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+#include <signal.h>
+#include <unistd.h>
+#elif defined (_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#define NOMINMAX
+#include <windows.h>
+#include <signal.h>
+#endif
+
+// Used for debugging to print out beam tokens.
+struct ostream_beam_view {
+    llama_context * ctx;
+    llama_beam_view beam_view;
+};
+std::ostream& operator<<(std::ostream& os, const ostream_beam_view & obv) {
+    os << "p(" << obv.beam_view.p << ") eob(" << std::boolalpha << obv.beam_view.eob << ") tokens(";
+    for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
+        os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
+    }
+    return os << ')';
+}
+
+// Put here anything you want back in beam_search_callback().
+struct beam_search_callback_data {
+    llama_context * ctx;
+    std::vector<llama_token> response;
+};
+
+// In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
+// For example, eob can be flagged due to maximum token length, stop words, etc.
+bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, const size_t n_tokens) {
+    return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// Custom callback example is called each time the beams lengths increase:
+//  * Show progress by printing ',' following by number of convergent beam tokens if any.
+//  * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
+//    This is also called when the stop condition is met.
+//    Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
+void beam_search_callback(void * callback_data_ptr, llama_beams_state beams_state) {
+    auto& callback_data = *static_cast<beam_search_callback_data*>(callback_data_ptr);
+    // Mark beams as EOS as needed.
+    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+        llama_beam_view& beam_view = beams_state.beam_views[i];
+        if (!beam_view.eob && is_at_eob(callback_data, beam_view.tokens, beam_view.n_tokens)) {
+            beam_view.eob = true;
+        }
+    }
+    printf(",");  // Show progress
+    if (const size_t n = beams_state.common_prefix_length) {
+        callback_data.response.resize(callback_data.response.size() + n);
+        assert(0u < beams_state.n_beams);
+        const llama_token * tokens = beams_state.beam_views[0].tokens;
+        std::copy(tokens, tokens + n, callback_data.response.end() - n);
+        printf("%lu", n);
+    }
+    fflush(stdout);
+#if 1 // DEBUG: print current beams for this iteration
+    std::cout << "\n\nCurrent beams (last_call=" << beams_state.last_call << "):\n";
+    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+        std::cout << "beams["<<i<<"]: " << ostream_beam_view{callback_data.ctx,beams_state.beam_views[i]} << std::endl;
+    }
+#endif
+}
+
+int main(int argc, char ** argv)
+{
+    gpt_params params;
+    //params.n_gpu_layers = 200;
+
+    //---------------------------------
+    // Print help :
+    //---------------------------------
+
+    if ( argc < 2 || argv[1][0] == '-' )
+    {
+        printf( "Usage: %s MODEL_PATH [BEAM_WIDTH=2] [PROMPT]\n" , argv[0] );
+        return 1 ;
+    }
+
+    //---------------------------------
+    // Load parameters :
+    //---------------------------------
+
+    params.model = argv[1];
+
+    params.n_beams = 2 < argc ? std::stoi(argv[2]) : 2;
+
+    if ( argc > 3 )
+    {
+        params.prompt = argv[3];
+    }
+
+    if ( params.prompt.empty() )
+    {
+        params.prompt = "### Request:\nHow many countries are there?\n\n### Response:\n";
+    }
+
+    //---------------------------------
+    // Init LLM :
+    //---------------------------------
+
+    llama_backend_init(params.numa);
+
+    llama_model * model;
+    llama_context * ctx;
+
+    std::tie(model, ctx) = llama_init_from_gpt_params( params );
+
+    if ( model == NULL )
+    {
+        fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
+        return 1;
+    }
+
+    //---------------------------------
+    // Tokenize the prompt :
+    //---------------------------------
+
+    std::vector<llama_token> tokens_list = llama_tokenize(ctx, params.prompt, true);
+
+    const size_t max_context_size     = llama_n_ctx( ctx );
+    const size_t max_tokens_list_size = max_context_size - 4 ;
+
+    if (tokens_list.size() > max_tokens_list_size)
+    {
+        fprintf( stderr , "%s: error: prompt too long (%lu tokens, max %lu)\n" ,
+             __func__ , tokens_list.size() , max_tokens_list_size );
+        return 1;
+    }
+
+    fprintf( stderr, "\n\n" );
+
+    // Print the tokens from the prompt :
+
+    for( auto id : tokens_list )
+    {
+        std::cout << llama_token_to_str(ctx, id);
+    }
+    std::cout << std::flush;
+
+    int n_past = llama_get_kv_cache_token_count(ctx);
+    if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
+    {
+        fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
+        return 1;
+    }
+    n_past += tokens_list.size();
+
+    beam_search_callback_data callback_data{ctx, {}};
+    size_t const beam_width = static_cast<size_t>(params.n_beams);
+    int const n_predict = 256;
+    llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict, params.n_threads);
+
+    std::cout << "\n\n";
+    for (llama_token const token_id : callback_data.response) {
+        std::cout << llama_token_to_str(ctx,token_id);
+    }
+    std::cout << std::endl;
+
+    llama_free( ctx );
+    llama_free_model( model );
+
+    llama_backend_free();
+
+    return 0;
+}
index 025b385cc8b1e8718174f25ecc44ee7dccbe14fa..3300553f9b397ab6519a43c63063de485735f674 100644 (file)
@@ -1209,6 +1209,62 @@ static void log_server_request(const Request &req, const Response &res)
                            });
 }
 
+bool is_at_eob(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) {
+    return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// Custom callback example is called each time the beams lengths increase:
+//  * Show progress by printing ',' following by number of convergent beam tokens if any.
+//  * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
+//    This is also called when the stop condition is met.
+//    Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
+void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
+    auto & llama = *static_cast<llama_server_context*>(callback_data);
+    // Mark beams as EOS as needed.
+    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+        llama_beam_view& beam_view = beams_state.beam_views[i];
+        if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
+            beam_view.eob = true;
+        }
+    }
+    printf(",");  // Show progress
+    if (const size_t n = beams_state.common_prefix_length) {
+        llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
+        assert(0u < beams_state.n_beams);
+        const llama_token * tokens = beams_state.beam_views[0].tokens;
+        const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
+        std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
+        printf("%lu", n);
+    }
+    fflush(stdout);
+#if 0 // DEBUG: print current beams for this iteration
+    std::cout << "\n\nCurrent beams:\n";
+    for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
+        std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl;
+    }
+#endif
+}
+
+struct token_translator {
+    llama_context * ctx;
+    std::string operator()(llama_token tok) const { return llama_token_to_str(ctx, tok); }
+    std::string operator()(completion_token_output cto) const { return (*this)(cto.tok); }
+};
+
+void append_to_generated_text_from_generated_token_probs(llama_server_context & llama) {
+    auto & gtps = llama.generated_token_probs;
+    auto translator = token_translator{llama.ctx};
+    auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); };
+    const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
+    if (llama.generated_text.capacity() < llama.generated_text.size() + len) {
+        llama.generated_text.reserve(llama.generated_text.size() + len);
+    }
+    for (const completion_token_output & cto : gtps) {
+        llama.generated_text += translator(cto);
+    }
+}
+
 int main(int argc, char **argv)
 {
     // own arguments required by this example
@@ -1291,22 +1347,30 @@ int main(int argc, char **argv)
         llama.beginCompletion();
 
         if (!llama.stream) {
-            size_t stop_pos = std::string::npos;
+            if (llama.params.n_beams) {
+                // Fill llama.generated_token_probs vector with final beam.
+                llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
+                                  llama.n_past, llama.n_remain, llama.params.n_threads);
+                // Translate llama.generated_token_probs to llama.generated_text.
+                append_to_generated_text_from_generated_token_probs(llama);
+            } else {
+                size_t stop_pos = std::string::npos;
 
-            while (llama.has_next_token) {
-                const completion_token_output token_with_probs = llama.doCompletion();
-                const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
+                while (llama.has_next_token) {
+                    const completion_token_output token_with_probs = llama.doCompletion();
+                    const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
 
-                stop_pos = llama.findStoppingStrings(llama.generated_text,
-                    token_text.size(), STOP_FULL);
-            }
+                    stop_pos = llama.findStoppingStrings(llama.generated_text,
+                        token_text.size(), STOP_FULL);
+                }
 
-            if (stop_pos == std::string::npos) {
-                stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
-            }
-            if (stop_pos != std::string::npos) {
-                llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
-                    llama.generated_text.end());
+                if (stop_pos == std::string::npos) {
+                    stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
+                }
+                if (stop_pos != std::string::npos) {
+                    llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
+                        llama.generated_text.end());
+                }
             }
 
             const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
index 4529ac82288549f2cebab140e810c11bc27f48b6..7d8b9a0ac485b114227b24f56005e57b5484d518 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -4326,6 +4326,257 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 }
 
+//
+// Beam search
+//
+
+struct llama_beam {
+    std::vector<llama_token> tokens;
+    float p;  // Cumulative beam probability (renormalized relative to all beams)
+    bool eob; // Initialize end-of-beam to false. Callback sets this to true.
+    // Sort beams by probability. In case of ties, prefer beams at eob.
+    bool operator<(const llama_beam & rhs) const {
+        return std::make_pair(p, eob) < std::make_pair(rhs.p, rhs.eob);
+    }
+    // Shift off first n tokens and discard them.
+    void shift_tokens(const size_t n) {
+        if (n) {
+            std::copy(tokens.begin() + n, tokens.end(), tokens.begin());
+            tokens.resize(tokens.size() - n);
+        }
+    }
+    llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eob}; }
+};
+
+// A struct for calculating logit-related info.
+struct llama_logit_info {
+    const float * const logits;
+    const int n_vocab;
+    const float max_l;
+    const float normalizer;
+    struct sum_exp {
+        float max_l;
+        float operator()(float sum, float l) const { return sum + std::exp(l - max_l); }
+    };
+    llama_logit_info(llama_context * ctx)
+      : logits(llama_get_logits(ctx))
+      , n_vocab(llama_n_vocab(ctx))
+      , max_l(*std::max_element(logits, logits + n_vocab))
+      , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l}))
+      { }
+    llama_token_data get_token_data(const llama_token token_id) const {
+        constexpr auto p = std::numeric_limits<float>::quiet_NaN();  // never used
+        return {token_id, logits[token_id], p};
+    }
+    // Return top k token_data by logit.
+    std::vector<llama_token_data> top_k(size_t k) {
+        std::vector<llama_token_data> min_heap;  // min-heap by logit
+        const llama_token k_min = std::min(static_cast<llama_token>(k), n_vocab);
+        min_heap.reserve(k_min);
+        for (llama_token token_id = 0 ; token_id < k_min ; ++token_id) {
+            min_heap.push_back(get_token_data(token_id));
+        }
+        auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; };
+        std::make_heap(min_heap.begin(), min_heap.end(), comp);
+        for (llama_token token_id = k_min ; token_id < n_vocab ; ++token_id) {
+            if (min_heap.front().logit < logits[token_id]) {
+                std::pop_heap(min_heap.begin(), min_heap.end(), comp);
+                min_heap.back().id = token_id;
+                min_heap.back().logit = logits[token_id];
+                std::push_heap(min_heap.begin(), min_heap.end(), comp);
+            }
+        }
+        return min_heap;
+    }
+    float probability_from_logit(float logit) {
+        return normalizer * std::exp(logit - max_l);
+    }
+};
+
+struct llama_beam_search_data {
+    llama_context * ctx;
+    size_t n_beams;
+    int n_past;
+    int n_predict;
+    int n_threads;
+    std::vector<llama_beam> beams;
+    std::vector<llama_beam> next_beams;
+
+    // Re-calculated on each loop iteration
+    size_t common_prefix_length;
+
+    // Used to communicate to/from callback on beams state.
+    std::vector<llama_beam_view> beam_views;
+
+    llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
+      : ctx(ctx)
+      , n_beams(n_beams)
+      , n_past(n_past)
+      , n_predict(n_predict)
+      , n_threads(n_threads)
+      , beam_views(n_beams) {
+        beams.reserve(n_beams);
+        next_beams.reserve(n_beams);
+    }
+
+    // Collapse beams to a single beam given by index.
+    void collapse_beams(const size_t beam_idx) {
+        if (0u < beam_idx) {
+            std::swap(beams[0], beams[beam_idx]);
+        }
+        beams.resize(1);
+    }
+
+    // Min-heaps are used to efficiently collect the top-k elements (k=n_beams).
+    // The repetative patterns below reflect the 2 stages of heaps:
+    //  * Gather elements until the vector is full, then call std::make_heap() on it.
+    //  * If the heap is full and a new element is found that should be included, pop the
+    //    least element to the back(), replace it with the new, then push it into the heap.
+    void fill_next_beams_by_top_probabilities(llama_beam & beam) {
+        // Min-heaps use a greater-than comparator.
+        const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; };
+        if (beam.eob) {
+            // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
+            if (next_beams.size() < n_beams) {
+                next_beams.push_back(std::move(beam));
+                if (next_beams.size() == n_beams) {
+                    std::make_heap(next_beams.begin(), next_beams.end(), comp);
+                }
+            } else if (next_beams.front().p < beam.p) {
+                std::pop_heap(next_beams.begin(), next_beams.end(), comp);
+                next_beams.back() = std::move(beam);
+                std::push_heap(next_beams.begin(), next_beams.end(), comp);
+            }
+        } else {
+            // beam is not at end-of-sentence, so branch with next top_k tokens.
+            if (!beam.tokens.empty()) {
+                llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads);
+            }
+            llama_logit_info logit_info(ctx);
+            std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams);
+            size_t i=0;
+            if (next_beams.size() < n_beams) {
+                for (; next_beams.size() < n_beams ; ++i) {
+                    llama_beam next_beam = beam;
+                    next_beam.tokens.push_back(next_tokens[i].id);
+                    next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit);
+                    next_beams.push_back(std::move(next_beam));
+                }
+                std::make_heap(next_beams.begin(), next_beams.end(), comp);
+            } else {
+                for (; next_beams.front().p == 0.0f ; ++i) {
+                    std::pop_heap(next_beams.begin(), next_beams.end(), comp);
+                    next_beams.back() = beam;
+                    next_beams.back().tokens.push_back(next_tokens[i].id);
+                    next_beams.back().p *= logit_info.probability_from_logit(next_tokens[i].logit);
+                    std::push_heap(next_beams.begin(), next_beams.end(), comp);
+                }
+            }
+            for (; i < n_beams ; ++i) {
+                const float next_p = beam.p * logit_info.probability_from_logit(next_tokens[i].logit);
+                if (next_beams.front().p < next_p) {
+                    std::pop_heap(next_beams.begin(), next_beams.end(), comp);
+                    next_beams.back() = beam;
+                    next_beams.back().tokens.push_back(next_tokens[i].id);
+                    next_beams.back().p = next_p;
+                    std::push_heap(next_beams.begin(), next_beams.end(), comp);
+                }
+            }
+        }
+    }
+
+    // Find common_prefix_length based on beams.
+    // Requires beams is not empty.
+    size_t find_common_prefix_length() {
+        size_t common_prefix_length = beams[0].tokens.size();
+        for (size_t i = 1 ; i < beams.size() ; ++i) {
+            common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size());
+            for (size_t j = 0 ; j < common_prefix_length ; ++j) {
+                if (beams[0].tokens[j] != beams[i].tokens[j]) {
+                    common_prefix_length = j;
+                    break;
+                }
+            }
+        }
+        return common_prefix_length;
+    }
+
+    // Construct beams_state to send back to caller via the callback function.
+    // Side effect: set common_prefix_length = find_common_prefix_length();
+    llama_beams_state get_beams_state(const bool last_call) {
+        for (size_t i = 0 ; i < beams.size() ; ++i) {
+            beam_views[i] = beams[i].view();
+        }
+        common_prefix_length = find_common_prefix_length();
+        return {beam_views.data(), beams.size(), common_prefix_length, last_call};
+    }
+
+    // Loop:
+    //  * while i < n_predict, AND
+    //  * any of the beams have not yet reached end-of-beam (eob), AND
+    //  * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence
+    //    (since all other beam probabilities can only decrease)
+    void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) {
+        beams.push_back({{}, 1.0f, false});  // Start with one empty beam w/ probability = 1.0 and !eob.
+        const auto not_eob = [](const llama_beam & beam) { return !beam.eob; };
+        for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eob) &&
+                       !beams[top_beam_index()].eob ; ++i) {
+            callback(callback_data, get_beams_state(false));  // Sets common_prefix_length
+            update_beams_from_beam_views();   // Update values (p,eob) that callback may have changed.
+            if (common_prefix_length) {
+                llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads);
+                n_past += common_prefix_length;
+            }
+            // Zero-out next_beam probabilities to place them last in following min-heap.
+            std::for_each(next_beams.begin(), next_beams.end(), [](llama_beam & beam) { beam.p = 0.0f; });
+            for (llama_beam & beam : beams) {
+                beam.shift_tokens(common_prefix_length);
+                fill_next_beams_by_top_probabilities(beam);
+            }
+            // next_beams become the beams of next/final iteration. Swap them to re-use memory.
+            beams.swap(next_beams);
+            renormalize_beam_probabilities(beams);
+        }
+        collapse_beams(top_beam_index());
+        callback(callback_data, get_beams_state(true));
+    }
+
+    // As beams grow, the cumulative probabilities decrease.
+    // Renormalize them to avoid floating point underflow.
+    static void renormalize_beam_probabilities(std::vector<llama_beam> & beams) {
+        const auto sum_p = [](float sum, llama_beam & beam) { return sum + beam.p; };
+        const float inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p);
+        std::for_each(beams.begin(), beams.end(), [=](llama_beam & beam) { beam.p *= inv_sum; });
+    }
+
+    // Assumes beams is non-empty.  Uses llama_beam::operator<() for ordering.
+    size_t top_beam_index() {
+        return std::max_element(beams.begin(), beams.end()) - beams.begin();
+    }
+
+    // Copy (p,eob) for each beam which may have been changed by the callback.
+    void update_beams_from_beam_views() {
+        for (size_t i = 0 ; i < beams.size() ; ++i) {
+            beams[i].p = beam_views[i].p;
+            beams[i].eob = beam_views[i].eob;
+        }
+    }
+};
+
+void llama_beam_search(llama_context * ctx,
+                       llama_beam_search_callback_fn_t callback, void * callback_data,
+                       size_t n_beams, int n_past, int n_predict, int n_threads) {
+    assert(ctx);
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict, n_threads);
+
+    beam_search_data.loop(callback, callback_data);
+
+    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    ctx->n_sample++;
+}
+
 //
 // quantization
 //
diff --git a/llama.h b/llama.h
index d474681725ff8389d31f7127e603cb3f16c23f0b..86737200fe3497e1ddb78ad45e6a313e95482867 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -469,6 +469,43 @@ extern "C" {
     /// @details Accepts the sampled token into the grammar
     LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
 
+    //
+    // Beam search
+    //
+
+    struct llama_beam_view {
+        const llama_token * tokens;
+        size_t n_tokens;
+        float p;   // Cumulative beam probability (renormalized relative to all beams)
+        bool eob;  // Callback should set this to true when a beam is at end-of-beam.
+    };
+
+    // Passed to beam_search_callback function.
+    // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
+    // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
+    // These pointers are valid only during the synchronous callback, so should not be saved.
+    struct llama_beams_state {
+        llama_beam_view * beam_views;
+        size_t n_beams;               // Number of elements in beam_views[].
+        size_t common_prefix_length;  // Current max length of prefix tokens shared by all beams.
+        bool last_call;               // True iff this is the last callback invocation.
+    };
+
+    // Type of pointer to the beam_search_callback function.
+    // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
+    // passed back to beam_search_callback. This avoids having to use global variables in the callback.
+    typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, llama_beams_state);
+
+    /// @details Deterministically returns entire sentence constructed by a beam search.
+    /// @param ctx Pointer to the llama_context.
+    /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
+    /// @param callback_data A pointer that is simply passed back to callback.
+    /// @param n_beams Number of beams to use.
+    /// @param n_past Number of tokens already evaluated.
+    /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
+    /// @param n_threads Number of threads as passed to llama_eval().
+    LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
+
     // Performance information
     LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
     LLAMA_API void llama_print_timings(struct llama_context * ctx);