]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Fix unicode in grammars (fixes #2501) (#2553)
authorEvan Jones <redacted>
Thu, 17 Aug 2023 23:54:44 +0000 (19:54 -0400)
committerGitHub <redacted>
Thu, 17 Aug 2023 23:54:44 +0000 (19:54 -0400)
* Fix unicode in grammars (fixes #2501)

* add more comments

* fix test-llama-grammar

llama.cpp
tests/test-llama-grammar.cpp

index b8cc229427620add977f403600ee0c48ae60efdf..e02b60596406a4c9ff8afabb9360d1d33c8e9294 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2077,37 +2077,81 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
 // grammar - internal
 //
 
+struct llama_partial_utf8 {
+    uint32_t value;    // bit value so far (unshifted)
+    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
+};
+
 struct llama_grammar {
     const std::vector<std::vector<llama_grammar_element>>   rules;
     std::vector<std::vector<const llama_grammar_element *>> stacks;
+
+    // buffer for partially generated UTF-8 sequence from accepted tokens
+    llama_partial_utf8                                      partial_utf8;
 };
 
 struct llama_grammar_candidate {
-    size_t           index;
-    const uint32_t * code_points;
+    size_t               index;
+    const uint32_t     * code_points;
+    llama_partial_utf8   partial_utf8;
 };
 
-// NOTE: assumes valid utf8 (but checks for overrun)
-// adds a terminating 0 for use as pointer
-std::vector<uint32_t> decode_utf8(const char * src) {
-    static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
+// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
+std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+        const char         * src,
+        llama_partial_utf8   partial_start) {
+    static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
     const char          * pos      = src;
     std::vector<uint32_t> code_points;
+    uint32_t              value    = partial_start.value;
+    int                   n_remain = partial_start.n_remain;
+
+    // continue previous decode, if applicable
+    while (*pos != 0 && n_remain > 0) {
+        uint8_t next_byte = static_cast<uint8_t>(*pos);
+        if ((next_byte >> 6) != 2) {
+            // invalid sequence, abort
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
+        }
+        value = (value << 6) + (next_byte & 0x3F);
+        ++pos;
+        --n_remain;
+    }
+
+    if (partial_start.n_remain > 0 && n_remain == 0) {
+        code_points.push_back(value);
+    }
+
+    // decode any subsequent utf-8 sequences, which may end in an incomplete one
     while (*pos != 0) {
         uint8_t  first_byte = static_cast<uint8_t>(*pos);
         uint8_t  highbits   = first_byte >> 4;
-        int      len        = lookup[highbits];
-        uint8_t  mask       = (1 << (8 - len)) - 1;
-        uint32_t value      = first_byte & mask;
-        const char * end    = pos + len; // may overrun!
+                 n_remain   = lookup[highbits] - 1;
+
+        if (n_remain < 0) {
+            // invalid sequence, abort
+            code_points.clear();
+            code_points.push_back(0);
+            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
+        }
+
+        uint8_t  mask       = (1 << (7 - n_remain)) - 1;
+                 value      = first_byte & mask;
         ++pos;
-        for ( ; pos < end && *pos != 0; ++pos) {
+        while (*pos != 0 && n_remain > 0) {
             value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+            ++pos;
+            --n_remain;
+        }
+        if (n_remain == 0) {
+            code_points.push_back(value);
         }
-        code_points.push_back(value);
     }
     code_points.push_back(0);
-    return code_points;
+
+    return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
 }
 
 // returns true iff pos points to the end of one of the definitions of a rule
@@ -2144,6 +2188,56 @@ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
     return std::make_pair(found == is_positive_char, pos);
 }
 
+// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
+// range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static bool llama_grammar_match_partial_char(
+        const llama_grammar_element * pos,
+        const llama_partial_utf8      partial_utf8) {
+
+    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
+    LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
+
+    uint32_t partial_value = partial_utf8.value;
+    int      n_remain      = partial_utf8.n_remain;
+
+    // invalid sequence or 7-bit char split across 2 bytes (overlong)
+    if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
+        return false;
+    }
+
+    // range of possible code points this partial UTF-8 sequence could complete to
+    uint32_t low  = partial_value << (n_remain * 6);
+    uint32_t high = low | ((1 << (n_remain * 6)) - 1);
+
+    if (low == 0) {
+        if (n_remain == 2) {
+            low = 1 << 11;
+        } else if (n_remain == 3) {
+            low = 1 << 16;
+        }
+    }
+
+    do {
+        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            if (pos->value <= high && low <= pos[1].value) {
+                return is_positive_char;
+            }
+            pos += 2;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            if (low <= pos->value && pos->value <= high) {
+                return is_positive_char;
+            }
+            pos += 1;
+        }
+    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
+
+    return !is_positive_char;
+}
+
+
 // transforms a grammar pushdown stack into N possible stacks, all ending
 // at a character range (terminal element)
 static void llama_grammar_advance_stack(
@@ -2244,8 +2338,11 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
     std::vector<llama_grammar_candidate> rejects;
 
     if (stack.empty()) {
-        // accept nothing; EOS is handled elsewhere
-        rejects.insert(rejects.end(), candidates.begin(), candidates.end());
+        for (auto tok : candidates) {
+            if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
+                rejects.push_back(tok);
+            }
+        }
         return rejects;
     }
 
@@ -2253,10 +2350,15 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
 
     std::vector<llama_grammar_candidate> next_candidates;
     for (auto tok : candidates) {
-        if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) {
-            if (tok.code_points[1] != 0) {
-                next_candidates.push_back({ tok.index, tok.code_points + 1 });
+        if (*tok.code_points == 0) {
+            // reached end of full codepoints in token, reject iff it ended in a partial sequence
+            // that cannot satisfy this position in grammar
+            if (tok.partial_utf8.n_remain != 0 &&
+                    !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
+                rejects.push_back(tok);
             }
+        } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
+            next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
         } else {
             rejects.push_back(tok);
         }
@@ -2274,7 +2376,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
 
     auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
     for (auto tok : next_rejects) {
-        rejects.push_back({ tok.index, tok.code_points - 1 });
+        rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
     }
 
     return rejects;
@@ -2339,7 +2441,7 @@ struct llama_grammar * llama_grammar_init(
         }
     } while (true);
 
-    return new llama_grammar{ std::move(vec_rules), std::move(stacks) };
+    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
 }
 
 void llama_grammar_free(struct llama_grammar * grammar) {
@@ -2645,8 +2747,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
 
     const llama_token eos = llama_token_eos();
 
-    std::vector<std::vector<uint32_t>>   candidates_decoded;
-    std::vector<llama_grammar_candidate> candidates_grammar;
+    std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
+    std::vector<llama_grammar_candidate>                              candidates_grammar;
 
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id  = candidates->data[i].id;
@@ -2658,8 +2760,10 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         } else if (*str == 0) {
             candidates->data[i].logit = -INFINITY;
         } else {
-            candidates_decoded.push_back(decode_utf8(str));
-            candidates_grammar.push_back({ i, candidates_decoded.back().data() });
+            candidates_decoded.push_back(decode_utf8(str, grammar->partial_utf8));
+            candidates_grammar.push_back({
+                i, candidates_decoded.back().first.data(), candidates_decoded.back().second
+            });
         }
     }
 
@@ -2860,11 +2964,14 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
     }
 
     const char * str = llama_token_to_str(ctx, token);
+
     // Note terminating 0 in decoded string
-    auto code_points = decode_utf8(str);
+    const auto   decoded     = decode_utf8(str, grammar->partial_utf8);
+    const auto & code_points = decoded.first;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
         grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
     }
+    grammar->partial_utf8 = decoded.second;
     LLAMA_ASSERT(!grammar->stacks.empty());
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
index f98c6531fa7b10c1293a5508ddf0eb78ca4cd7db..81c31e9e2e5d25862925361b8b687ad758767433 100644 (file)
@@ -199,7 +199,7 @@ int main()
         uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
         cp[0] = 37 + i;
         cp[1] = 0;
-        next_candidates[i] = {i, cp};
+        next_candidates[i] = {i, cp, {}};
     }
 
     std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {