]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
grammar : pre-computed pieces + reserve mem + less string copies (#4330)
authorMarcus Dunn <redacted>
Tue, 5 Dec 2023 20:55:12 +0000 (10:55 -1000)
committerGitHub <redacted>
Tue, 5 Dec 2023 20:55:12 +0000 (22:55 +0200)
* reserve space for codepoints

* improvement for the appended 0

* used precomputed token text for grammar sample

* reserve canidates_decoded

* reserve canidates_grammar

* remove candidates_decoded

* Revert "remove candidates_decoded"

This reverts commit 3773328080e6a139ee83198329a13cf4ff61d707.

* changed decode_utf8 to take src by ref

llama.cpp

index b77020e10d8a5f467604041bbd54e0805ba8478d..14e5d312e6ffc08fe4f8f79cbcb898166c9f18ce 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -6851,14 +6851,13 @@ struct llama_grammar_candidate {
 // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
 // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
 static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
-        const char         * src,
-        size_t               n_src,
+        const std::string & src,
         llama_partial_utf8   partial_start) {
     static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
-    const char          * pos      = src;
+    const char          * pos      = src.c_str();
     std::vector<uint32_t> code_points;
     // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
-    code_points.reserve(n_src + 1);
+    code_points.reserve(src.size() + 1);
     uint32_t              value    = partial_start.value;
     int                   n_remain = partial_start.n_remain;
 
@@ -6909,13 +6908,6 @@ static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
     return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
 }
 
-static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
-        std::string src,
-        llama_partial_utf8 partial_start
-) {
-    return decode_utf8(src.c_str(), src.size(), partial_start);
-}
-
 // returns true iff pos points to the end of one of the definitions of a rule
 static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
     switch (pos->type) {
@@ -7554,11 +7546,13 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
     const llama_token eos = llama_token_eos(&ctx->model);
 
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
+    candidates_decoded.reserve(candidates->size);
     std::vector<llama_grammar_candidate>                              candidates_grammar;
+    candidates_grammar.reserve(candidates->size);
 
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id    = candidates->data[i].id;
-        const std::string piece = llama_token_to_piece(ctx, id);
+        const std::string & piece = ctx->model.vocab.id_to_token[id].text;
         if (id == eos) {
             if (!allow_eos) {
                 candidates->data[i].logit = -INFINITY;
@@ -7770,7 +7764,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         GGML_ASSERT(false);
     }
 
-    const std::string piece = llama_token_to_piece(ctx, token);
+    const std::string & piece = ctx->model.vocab.id_to_token[token].text;
 
     // Note terminating 0 in decoded string
     const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);