]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : grammar `reserve` space in `decode_utf8` (#4210)
authorMarcus Dunn <redacted>
Sat, 25 Nov 2023 16:58:23 +0000 (08:58 -0800)
committerGitHub <redacted>
Sat, 25 Nov 2023 16:58:23 +0000 (18:58 +0200)
* reserve space for codepoints

* improvement for the appended 0

llama.cpp

index c5f4053f2fffbe196ce091db0a45e95798c1444d..f2b5967d791e92d4c9f0ba5b61da4330622dc0f1 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -6420,10 +6420,13 @@ struct llama_grammar_candidate {
 // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
 static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const char         * src,
+        size_t               n_src,
         llama_partial_utf8   partial_start) {
     static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
     const char          * pos      = src;
     std::vector<uint32_t> code_points;
+    // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
+    code_points.reserve(n_src + 1);
     uint32_t              value    = partial_start.value;
     int                   n_remain = partial_start.n_remain;
 
@@ -6474,6 +6477,13 @@ static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
     return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
 }
 
+static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+        std::string src,
+        llama_partial_utf8 partial_start
+) {
+    return decode_utf8(src.c_str(), src.size(), partial_start);
+}
+
 // returns true iff pos points to the end of one of the definitions of a rule
 static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
     switch (pos->type) {
@@ -7123,7 +7133,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         } else if (piece.empty() || piece[0] == 0) {
             candidates->data[i].logit = -INFINITY;
         } else {
-            candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8));
+            candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
         }
     }
@@ -7330,7 +7340,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
     const std::string piece = llama_token_to_piece(ctx, token);
 
     // Note terminating 0 in decoded string
-    const auto   decoded     = decode_utf8(piece.c_str(), grammar->partial_utf8);
+    const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);
     const auto & code_points = decoded.first;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
         grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);