git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
jinja : refactor token advancement (#20864)
author: Sigbjørn Skjæret <redacted>
Sun, 22 Mar 2026 16:45:10 +0000 (17:45 +0100)
committer: GitHub <redacted>
Sun, 22 Mar 2026 16:45:10 +0000 (17:45 +0100)
* refactor token advancement

* exercise sub-expressions

common/jinja/parser.cpp
tests/test-jinja.cpp

index 7970336ac01e141126c85a1aa9b5c2dbe23964a4..4ae4477445bbfd15bae8f268b0a797743ce819bd 100644 (file)
@@ -53,6 +53,13 @@ private:
         return tokens[current + offset];
     }
 
+    const token & next() { // return the token at `current`, then advance `current` past it
+        if (current >= tokens.size()) { // nothing left to consume
+            throw parser_exception("Parser Error: Unexpected EOF", source, tokens.empty() ? 0 : tokens.back().pos); // reported position: last token's pos, or 0 when there are no tokens
+        }
+        return tokens[current++]; // post-increment: yields the token at the old index and moves to the following one
+    }
+
     token expect(token::type type, const std::string&  error) {
         const auto & t = peek();
         if (t.t != type) {
@@ -90,9 +97,9 @@ private:
         size_t start_pos = current;
         switch (peek().t) {
             case token::comment:
-                return mk_stmt<comment_statement>(start_pos, tokens[current++].value);
+                return mk_stmt<comment_statement>(start_pos, next().value);
             case token::text:
-                return mk_stmt<string_literal>(start_pos, tokens[current++].value);
+                return mk_stmt<string_literal>(start_pos, next().value);
             case token::open_statement:
                 return parse_jinja_statement();
             case token::open_expression:
@@ -119,8 +126,7 @@ private:
         }
 
         size_t start_pos = current;
-        std::string name = peek().value;
-        current++; // consume identifier
+        std::string name = next().value;
 
         statement_ptr result;
         if (name == "set") {
@@ -202,7 +208,7 @@ private:
             // Ignore generation blocks (transformers-specific)
             // See https://github.com/huggingface/transformers/pull/30650 for more information.
             result = mk_stmt<noop_statement>(start_pos);
-            current++;
+            ++current;
 
         } else {
             throw std::runtime_error("Unknown statement: " + name);
@@ -217,7 +223,7 @@ private:
         statements body;
 
         if (is(token::equals)) {
-            current++;
+            ++current;
             value = parse_expression_sequence();
         } else {
             // parsing multiline set here
@@ -280,7 +286,7 @@ private:
         exprs.push_back(primary ? parse_primary_expression() : parse_expression());
         bool is_tuple = is(token::comma);
         while (is(token::comma)) {
-            current++; // consume comma
+            ++current; // consume comma
             exprs.push_back(primary ? parse_primary_expression() : parse_expression());
         }
         return is_tuple ? mk_stmt<tuple_literal>(start_pos, std::move(exprs)) : std::move(exprs[0]);
@@ -290,7 +296,7 @@ private:
         // e.g., `message` in `for message in messages`
         auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple
         if (!is_identifier("in")) throw std::runtime_error("Expected 'in'");
-        current++;
+        ++current; // consume 'in'
 
         // `messages` in `for message in messages`
         auto iterable = parse_expression();
@@ -305,7 +311,8 @@ private:
         }
 
         if (is_statement({"else"})) {
-            current += 2;
+            ++current; // consume {%
+            ++current; // consume 'else'
             expect(token::close_statement, "Expected %}");
             while (!is_statement({"endfor"})) {
                 alternate.push_back(parse_any());
@@ -347,7 +354,7 @@ private:
         auto left = parse_logical_and_expression();
         while (is_identifier("or")) {
             size_t start_pos = current;
-            token op = tokens[current++];
+            token op = next();
             left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_and_expression());
         }
         return left;
@@ -357,7 +364,7 @@ private:
         auto left = parse_logical_negation_expression();
         while (is_identifier("and")) {
             size_t start_pos = current;
-            auto op = tokens[current++];
+            auto op = next();
             left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_negation_expression());
         }
         return left;
@@ -367,7 +374,7 @@ private:
         // Try parse unary operators
         if (is_identifier("not")) {
             size_t start_pos = current;
-            auto op = tokens[current++];
+            auto op = next();
             return mk_stmt<unary_expression>(start_pos, op, parse_logical_negation_expression());
         }
         return parse_comparison_expression();
@@ -382,11 +389,12 @@ private:
             size_t start_pos = current;
             if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") {
                 op = {token::identifier, "not in", tokens[current].pos};
-                current += 2;
+                ++current; // consume 'not'
+                ++current; // consume 'in'
             } else if (is_identifier("in")) {
-                op = tokens[current++];
+                op = next();
             } else if (is(token::comparison_binary_operator)) {
-                op = tokens[current++];
+                op = next();
             } else break;
             left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_additive_expression());
         }
@@ -397,7 +405,7 @@ private:
         auto left = parse_multiplicative_expression();
         while (is(token::additive_binary_operator)) {
             size_t start_pos = current;
-            auto op = tokens[current++];
+            auto op = next();
             left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_multiplicative_expression());
         }
         return left;
@@ -407,7 +415,7 @@ private:
         auto left = parse_test_expression();
         while (is(token::multiplicative_binary_operator)) {
             size_t start_pos = current;
-            auto op = tokens[current++];
+            auto op = next();
             left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_test_expression());
         }
         return left;
@@ -417,9 +425,9 @@ private:
         auto operand = parse_filter_expression();
         while (is_identifier("is")) {
             size_t start_pos = current;
-            current++;
+            ++current; // consume 'is'
             bool negate = false;
-            if (is_identifier("not")) { current++; negate = true; }
+            if (is_identifier("not")) { ++current; negate = true; }
             auto test_id = parse_primary_expression();
             // FIXME: tests can also be expressed like this: if x is eq 3
             if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id));
@@ -432,7 +440,7 @@ private:
         auto operand = parse_call_member_expression();
         while (is(token::pipe)) {
             size_t start_pos = current;
-            current++;
+            ++current; // consume pipe
             auto filter = parse_primary_expression();
             if (is(token::open_paren)) filter = parse_call_expression(std::move(filter));
             operand = mk_stmt<filter_expression>(start_pos, std::move(operand), std::move(filter));
@@ -490,7 +498,7 @@ private:
     statement_ptr parse_member_expression(statement_ptr object) {
         size_t start_pos = current;
         while (is(token::dot) || is(token::open_square_bracket)) {
-            auto op = tokens[current++];
+            auto op = next();
             bool computed = op.t == token::open_square_bracket;
             statement_ptr prop;
             if (computed) {
@@ -536,7 +544,7 @@ private:
 
     statement_ptr parse_primary_expression() {
         size_t start_pos = current;
-        auto t = tokens[current++];
+        auto t = next();
         switch (t.t) {
             case token::numeric_literal:
                 if (t.value.find('.') != std::string::npos) {
@@ -547,7 +555,7 @@ private:
             case token::string_literal: {
                 std::string val = t.value;
                 while (is(token::string_literal)) {
-                    val += tokens[current++].value;
+                    val += next().value;
                 }
                 return mk_stmt<string_literal>(start_pos, val);
             }
@@ -562,9 +570,9 @@ private:
                 statements vals;
                 while (!is(token::close_square_bracket)) {
                     vals.push_back(parse_expression());
-                    if (is(token::comma)) current++;
+                    if (is(token::comma)) ++current;
                 }
-                current++;
+                ++current;
                 return mk_stmt<array_literal>(start_pos, std::move(vals));
             }
             case token::open_curly_bracket: {
@@ -573,9 +581,9 @@ private:
                     auto key = parse_expression();
                     expect(token::colon, "Expected :");
                     pairs.push_back({std::move(key), parse_expression()});
-                    if (is(token::comma)) current++;
+                    if (is(token::comma)) ++current;
                 }
-                current++;
+                ++current;
                 return mk_stmt<object_literal>(start_pos, std::move(pairs));
             }
             default:
index ef9c8f73c8b99fa7762753b501d185c9f33fde78..1550627bf09a8e43bd011e831d181b4a1c3a4f13 100644 (file)
@@ -2264,6 +2264,7 @@ static void test_fuzzing(testing & t) {
 
     t.test("malformed templates (should error, not crash)", [&](testing & t) {
         const std::vector<std::string> malformed = {
+            "",
             "{{ x",
             "{% if %}",
             "{% for %}",
@@ -2284,6 +2285,11 @@ static void test_fuzzing(testing & t) {
         for (const auto & tmpl : malformed) {
             t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object()));
         }
+        std::string tmpl = "{% for message in messages %}{{ message.role | string }} : {{ message.content if ('content' in message and message.content is not none) }}{% endfor %";
+        while (tmpl.length() > 0) {
+            t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object()));
+            tmpl.pop_back();
+        }
     });
 
     t.test("type coercion edge cases", [&](testing & t) {