]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
jinja : implement mixed type object keys (#18955)
authorSigbjørn Skjæret <redacted>
Tue, 27 Jan 2026 18:50:42 +0000 (19:50 +0100)
committerGitHub <redacted>
Tue, 27 Jan 2026 18:50:42 +0000 (19:50 +0100)
* implement mixed type object keys

* add tests

* refactor

* minor fixes

* massive refactor

* add more tests

* forgotten tuples

* fix array/object is_hashable

* correct (albeit broken) jinja responses

verified with transformers

* improved hashing and equality

* refactor hash function

* more exhausive test case

* clean up

* cont

* cont (2)

* missing cstring

---------

Co-authored-by: Xuan Son Nguyen <redacted>
common/jinja/runtime.cpp
common/jinja/runtime.h
common/jinja/string.cpp
common/jinja/string.h
common/jinja/utils.h
common/jinja/value.cpp
common/jinja/value.h
tests/test-chat-template.cpp
tests/test-jinja.cpp

index e3e4ebf1ec2c4daffbf1724bf47052b16664dccc..f234d9284fc610e90bdafcd5ef5629a659cdea1c 100644 (file)
@@ -44,6 +44,12 @@ static std::string get_line_col(const std::string & source, size_t pos) {
     return "line " + std::to_string(line) + ", column " + std::to_string(col);
 }
 
+static void ensure_key_type_allowed(const value & val) {
+    if (!val->is_hashable()) {
+        throw std::runtime_error("Type: " + val->type() + " is not allowed as object key");
+    }
+}
+
 // execute with error handling
 value statement::execute(context & ctx) {
     try {
@@ -95,20 +101,10 @@ value identifier::execute_impl(context & ctx) {
 value object_literal::execute_impl(context & ctx) {
     auto obj = mk_val<value_object>();
     for (const auto & pair : val) {
-        value key_val = pair.first->execute(ctx);
-        if (!is_val<value_string>(key_val) && !is_val<value_int>(key_val)) {
-            throw std::runtime_error("Object literal: keys must be string or int values, got " + key_val->type());
-        }
-        std::string key = key_val->as_string().str();
+        value key = pair.first->execute(ctx);
         value val = pair.second->execute(ctx);
-        JJ_DEBUG("Object literal: setting key '%s' with value type %s", key.c_str(), val->type().c_str());
+        JJ_DEBUG("Object literal: setting key '%s' with value type %s", key->as_string().str().c_str(), val->type().c_str());
         obj->insert(key, val);
-
-        if (is_val<value_int>(key_val)) {
-            obj->val_obj.is_key_numeric = true;
-        } else if (obj->val_obj.is_key_numeric) {
-            throw std::runtime_error("Object literal: cannot mix numeric and non-numeric keys");
-        }
     }
     return obj;
 }
@@ -127,9 +123,9 @@ value binary_expression::execute_impl(context & ctx) {
     value right_val = right->execute(ctx);
     JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right_val->type().c_str());
     if (op.value == "==") {
-        return mk_val<value_bool>(value_compare(left_val, right_val, value_compare_op::eq));
+        return mk_val<value_bool>(*left_val == *right_val);
     } else if (op.value == "!=") {
-        return mk_val<value_bool>(!value_compare(left_val, right_val, value_compare_op::eq));
+        return mk_val<value_bool>(!(*left_val == *right_val));
     }
 
     auto workaround_concat_null_with_str = [&](value & res) -> bool {
@@ -230,7 +226,7 @@ value binary_expression::execute_impl(context & ctx) {
         auto & arr = right_val->as_array();
         bool member = false;
         for (const auto & item : arr) {
-            if (value_compare(left_val, item, value_compare_op::eq)) {
+            if (*left_val == *item) {
                 member = true;
                 break;
             }
@@ -265,10 +261,9 @@ value binary_expression::execute_impl(context & ctx) {
         }
     }
 
-    // String in object
-    if (is_val<value_string>(left_val) && is_val<value_object>(right_val)) {
-        auto key = left_val->as_string().str();
-        bool has_key = right_val->has_key(key);
+    // Value key in object
+    if (is_val<value_object>(right_val)) {
+        bool has_key = right_val->has_key(left_val);
         if (op.value == "in") {
             return mk_val<value_bool>(has_key);
         } else if (op.value == "not in") {
@@ -465,14 +460,8 @@ value for_statement::execute_impl(context & ctx) {
         JJ_DEBUG("%s", "For loop over object keys");
         auto & obj = iterable_val->as_ordered_object();
         for (auto & p : obj) {
-            auto tuple = mk_val<value_array>();
-            if (iterable_val->val_obj.is_key_numeric) {
-                tuple->push_back(mk_val<value_int>(std::stoll(p.first)));
-            } else {
-                tuple->push_back(mk_val<value_string>(p.first));
-            }
-            tuple->push_back(p.second);
-            items.push_back(tuple);
+            auto tuple = mk_val<value_tuple>(p);
+            items.push_back(std::move(tuple));
         }
         if (ctx.is_get_stats) {
             iterable_val->stats.used = true;
@@ -602,11 +591,13 @@ value set_statement::execute_impl(context & ctx) {
     auto rhs = val ? val->execute(ctx) : exec_statements(body, ctx);
 
     if (is_stmt<identifier>(assignee)) {
+        // case: {% set my_var = value %}
         auto var_name = cast_stmt<identifier>(assignee)->val;
         JJ_DEBUG("Setting global variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str());
         ctx.set_val(var_name, rhs);
 
     } else if (is_stmt<tuple_literal>(assignee)) {
+        // case: {% set a, b = value %}
         auto tuple = cast_stmt<tuple_literal>(assignee);
         if (!is_val<value_array>(rhs)) {
             throw std::runtime_error("Cannot unpack non-iterable type in set: " + rhs->type());
@@ -625,6 +616,7 @@ value set_statement::execute_impl(context & ctx) {
         }
 
     } else if (is_stmt<member_expression>(assignee)) {
+        // case: {% set ns.my_var = value %}
         auto member = cast_stmt<member_expression>(assignee);
         if (member->computed) {
             throw std::runtime_error("Cannot assign to computed member");
@@ -767,22 +759,22 @@ value member_expression::execute_impl(context & ctx) {
     }
 
     JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str());
+    ensure_key_type_allowed(property);
 
     value val = mk_val<value_undefined>("object_property");
 
     if (is_val<value_undefined>(object)) {
         JJ_DEBUG("%s", "Accessing property on undefined object, returning undefined");
         return val;
+
     } else if (is_val<value_object>(object)) {
-        if (!is_val<value_string>(property)) {
-            throw std::runtime_error("Cannot access object with non-string: got " + property->type());
-        }
         auto key = property->as_string().str();
-        val = object->at(key, val);
+        val = object->at(property, val);
         if (is_val<value_undefined>(val)) {
             val = try_builtin_func(ctx, key, object, true);
         }
         JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str());
+
     } else if (is_val<value_array>(object) || is_val<value_string>(object)) {
         if (is_val<value_int>(property)) {
             int64_t index = property->as_int();
@@ -806,6 +798,7 @@ value member_expression::execute_impl(context & ctx) {
             auto key = property->as_string().str();
             JJ_DEBUG("Accessing %s built-in '%s'", is_val<value_array>(object) ? "array" : "string", key.c_str());
             val = try_builtin_func(ctx, key, object, true);
+
         } else {
             throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type());
         }
index dc7f4e471c1784f034b446b25a7fdd59bd687969..17a6dff5aa210167b0f6dd763cd961c7e14d7474 100644 (file)
@@ -79,18 +79,18 @@ struct context {
     }
 
     value get_val(const std::string & name) {
-        auto it = env->val_obj.unordered.find(name);
-        if (it != env->val_obj.unordered.end()) {
-            return it->second;
-        } else {
-            return mk_val<value_undefined>(name);
-        }
+        value default_val = mk_val<value_undefined>(name);
+        return env->at(name, default_val);
     }
 
     void set_val(const std::string & name, const value & val) {
         env->insert(name, val);
     }
 
+    void set_val(const value & name, const value & val) {
+        env->insert(name, val);
+    }
+
     void print_vars() const {
         printf("Context Variables:\n%s\n", value_to_json(env, 2).c_str());
     }
@@ -344,9 +344,19 @@ struct array_literal : public expression {
     }
 };
 
-struct tuple_literal : public array_literal {
-    explicit tuple_literal(statements && val) : array_literal(std::move(val)) {}
+struct tuple_literal : public expression {
+    statements val;
+    explicit tuple_literal(statements && val) : val(std::move(val)) {
+        for (const auto& item : this->val) chk_type<expression>(item);
+    }
     std::string type() const override { return "TupleLiteral"; }
+    value execute_impl(context & ctx) override {
+        auto arr = mk_val<value_array>();
+        for (const auto & item_stmt : val) {
+            arr->push_back(item_stmt->execute(ctx));
+        }
+        return mk_val<value_tuple>(std::move(arr->as_array()));
+    }
 };
 
 struct object_literal : public expression {
index 21ebde39e3e7e48e545742a2792937ee442113f7..8087e15b350284482d8c95bea6d47291eccd2797 100644 (file)
@@ -61,6 +61,12 @@ size_t string::length() const {
     return len;
 }
 
+void string::hash_update(hasher & hash) const noexcept {
+    for (const auto & part : parts) {
+        hash.update(part.val.data(), part.val.length());
+    }
+}
+
 bool string::all_parts_are_input() const {
     for (const auto & part : parts) {
         if (!part.is_input) {
index 78457f9e413e88c056464f2b95fd60ea3a1fa3f7..c4963000adb824fd2b015b3253e50b68e81852ca 100644 (file)
@@ -4,6 +4,8 @@
 #include <string>
 #include <vector>
 
+#include "utils.h"
+
 namespace jinja {
 
 // allow differentiate between user input strings and template strings
@@ -37,6 +39,7 @@ struct string {
 
     std::string str() const;
     size_t length() const;
+    void hash_update(hasher & hash) const noexcept;
     bool all_parts_are_input() const;
     bool is_uppercase() const;
     bool is_lowercase() const;
index 1e9f2a12a1a3566db06ab4d365a0c55baba4b14b..de6947fc28f3423dc3beab7cc468096b6b3d06c6 100644 (file)
@@ -3,6 +3,8 @@
 #include <string>
 #include <sstream>
 #include <algorithm>
+#include <cstdint>
+#include <cstring>
 
 namespace jinja {
 
@@ -46,4 +48,102 @@ static std::string fmt_error_with_source(const std::string & tag, const std::str
     return oss.str();
 }
 
+// Note: this is a simple hasher, not cryptographically secure, just for hash table usage
+struct hasher {
+    static constexpr auto size_t_digits = sizeof(size_t) * 8;
+    static constexpr size_t prime = size_t_digits == 64 ? 0x100000001b3 : 0x01000193;
+    static constexpr size_t seed = size_t_digits == 64 ? 0xcbf29ce484222325 : 0x811c9dc5;
+    static constexpr auto block_size = sizeof(size_t); // in bytes; allowing the compiler to vectorize the computation
+
+    static_assert(size_t_digits == 64 || size_t_digits == 32);
+    static_assert(block_size == 8 || block_size == 4);
+
+    uint8_t buffer[block_size];
+    size_t idx = 0; // current index in buffer
+    size_t state = seed;
+
+    hasher() = default;
+    hasher(const std::type_info & type_inf) noexcept {
+        const auto type_hash = type_inf.hash_code();
+        update(&type_hash, sizeof(type_hash));
+    }
+
+    // Properties:
+    //   - update is not associative: update(a).update(b) != update(b).update(a)
+    //   - update(a ~ b) == update(a).update(b) with ~ as concatenation operator --> useful for streaming
+    //   - update("", 0) --> state unchanged with empty input
+    hasher& update(void const * bytes, size_t len) noexcept {
+        const uint8_t * c = static_cast<uint8_t const *>(bytes);
+        if (len == 0) {
+            return *this;
+        }
+        size_t processed = 0;
+
+        // first, fill the existing buffer if it's partial
+        if (idx > 0) {
+            size_t to_fill = block_size - idx;
+            if (to_fill > len) {
+                to_fill = len;
+            }
+            std::memcpy(buffer + idx, c, to_fill);
+            idx += to_fill;
+            processed += to_fill;
+            if (idx == block_size) {
+                update_block(buffer);
+                idx = 0;
+            }
+        }
+
+        // process full blocks from the remaining input
+        for (; processed + block_size <= len; processed += block_size) {
+            update_block(c + processed);
+        }
+
+        // buffer any remaining bytes
+        size_t remaining = len - processed;
+        if (remaining > 0) {
+            std::memcpy(buffer, c + processed, remaining);
+            idx = remaining;
+        }
+        return *this;
+    }
+
+    // convenience function for testing only
+    hasher& update(const std::string & s) noexcept {
+        return update(s.data(), s.size());
+    }
+
+    // finalize and get the hash value
+    // note: after calling digest, the hasher state is modified, do not call update() again
+    size_t digest() noexcept {
+        // if there are remaining bytes in buffer, fill the rest with zeros and process
+        if (idx > 0) {
+            for (size_t i = idx; i < block_size; ++i) {
+                buffer[i] = 0;
+            }
+            update_block(buffer);
+            idx = 0;
+        }
+
+        return state;
+    }
+
+private:
+    // IMPORTANT: block must have at least block_size bytes
+    void update_block(const uint8_t * block) noexcept {
+        size_t blk = static_cast<uint32_t>(block[0])
+                    | (static_cast<uint32_t>(block[1]) << 8)
+                    | (static_cast<uint32_t>(block[2]) << 16)
+                    | (static_cast<uint32_t>(block[3]) << 24);
+        if constexpr (block_size == 8) {
+            blk = blk | (static_cast<uint64_t>(block[4]) << 32)
+                      | (static_cast<uint64_t>(block[5]) << 40)
+                      | (static_cast<uint64_t>(block[6]) << 48)
+                      | (static_cast<uint64_t>(block[7]) << 56);
+        }
+        state ^= blk;
+        state *= prime;
+    }
+};
+
 } // namespace jinja
index d2ed8242699c01fda4ca8391ba34cbdf71be2dc5..2d77068143587dc322a71defbe140b8b0153868e 100644 (file)
@@ -163,7 +163,7 @@ static value selectattr(const func_args & args) {
     args.ensure_vals<value_array, value_string, value_string, value_string>(true, true, false, false);
 
     auto arr = args.get_pos(0)->as_array();
-    auto attr_name = args.get_pos(1)->as_string().str();
+    auto attribute = args.get_pos(1);
     auto out = mk_val<value_array>();
     value val_default = mk_val<value_undefined>();
 
@@ -173,7 +173,7 @@ static value selectattr(const func_args & args) {
             if (!is_val<value_object>(item)) {
                 throw raised_exception("selectattr: item is not an object");
             }
-            value attr_val = item->at(attr_name, val_default);
+            value attr_val = item->at(attribute, val_default);
             bool is_selected = attr_val->as_bool();
             if constexpr (is_reject) is_selected = !is_selected;
             if (is_selected) out->push_back(item);
@@ -217,7 +217,7 @@ static value selectattr(const func_args & args) {
             if (!is_val<value_object>(item)) {
                 throw raised_exception("selectattr: item is not an object");
             }
-            value attr_val = item->at(attr_name, val_default);
+            value attr_val = item->at(attribute, val_default);
             func_args test_args(args.ctx);
             test_args.push_back(attr_val); // attribute value
             test_args.push_back(extra_arg); // extra argument
@@ -741,6 +741,7 @@ const func_builtins & value_array_t::get_builtins() const {
             args.ensure_count(1, 4);
             args.ensure_vals<value_array, value_int, value_int, value_int>(true, true, false, false);
 
+            auto val  = args.get_pos(0);
             auto arg0 = args.get_pos(1);
             auto arg1 = args.get_pos(2, mk_val<value_undefined>());
             auto arg2 = args.get_pos(3, mk_val<value_undefined>());
@@ -762,10 +763,8 @@ const func_builtins & value_array_t::get_builtins() const {
             if (step == 0) {
                 throw raised_exception("slice step cannot be zero");
             }
-            auto arr = slice(args.get_pos(0)->as_array(), start, stop, step);
-            auto res = mk_val<value_array>();
-            res->val_arr = std::move(arr);
-            return res;
+            auto arr = slice(val->as_array(), start, stop, step);
+            return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
         }},
         {"selectattr", selectattr<false>},
         {"select", selectattr<false>},
@@ -785,15 +784,14 @@ const func_builtins & value_array_t::get_builtins() const {
             }
             const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
             const std::string delim = val_delim->is_undefined() ? "" : val_delim->as_string().str();
-            const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
             std::string result;
             for (size_t i = 0; i < arr.size(); ++i) {
                 value val_arr = arr[i];
                 if (!attribute->is_undefined()) {
                     if (attr_is_int && is_val<value_array>(val_arr)) {
                         val_arr = val_arr->at(attr_int);
-                    } else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(val_arr)) {
-                        val_arr = val_arr->at(attr_name);
+                    } else if (!attr_is_int && is_val<value_object>(val_arr)) {
+                        val_arr = val_arr->at(attribute);
                     }
                 }
                 if (!is_val<value_string>(val_arr) && !is_val<value_int>(val_arr) && !is_val<value_float>(val_arr)) {
@@ -808,9 +806,7 @@ const func_builtins & value_array_t::get_builtins() const {
         }},
         {"string", [](const func_args & args) -> value {
             args.ensure_vals<value_array>();
-            auto str = mk_val<value_string>();
-            gather_string_parts_recursive(args.get_pos(0), str);
-            return str;
+            return mk_val<value_string>(args.get_pos(0)->as_string());
         }},
         {"tojson", tojson},
         {"map", [](const func_args & args) -> value {
@@ -821,26 +817,26 @@ const func_builtins & value_array_t::get_builtins() const {
             if (!is_val<value_kwarg>(args.get_args().at(1))) {
                 throw not_implemented_exception("map: filter-mapping not implemented");
             }
+            value val       = args.get_pos(0);
             value attribute = args.get_kwarg_or_pos("attribute", 1);
             const bool attr_is_int = is_val<value_int>(attribute);
             if (!is_val<value_string>(attribute) && !attr_is_int) {
                 throw raised_exception("map: attribute must be string or integer");
             }
             const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
-            const std::string attr_name = attribute->as_string().str();
             value default_val = args.get_kwarg("default", mk_val<value_undefined>());
             auto out = mk_val<value_array>();
-            auto arr = args.get_pos(0)->as_array();
+            auto arr = val->as_array();
             for (const auto & item : arr) {
                 value attr_val;
                 if (attr_is_int) {
                     attr_val = is_val<value_array>(item) ? item->at(attr_int, default_val) : default_val;
                 } else {
-                    attr_val = is_val<value_object>(item) ? item->at(attr_name, default_val) : default_val;
+                    attr_val = is_val<value_object>(item) ? item->at(attribute, default_val) : default_val;
                 }
                 out->push_back(attr_val);
             }
-            return out;
+            return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(out->as_array())) : out;
         }},
         {"append", [](const func_args & args) -> value {
             args.ensure_count(2);
@@ -867,6 +863,7 @@ const func_builtins & value_array_t::get_builtins() const {
             if (!is_val<value_array>(args.get_pos(0))) {
                 throw raised_exception("sort: first argument must be an array");
             }
+            value val         = args.get_pos(0);
             value val_reverse = args.get_kwarg_or_pos("reverse",        1);
             value val_case    = args.get_kwarg_or_pos("case_sensitive", 2);
             value attribute   = args.get_kwarg_or_pos("attribute",      3);
@@ -875,8 +872,7 @@ const func_builtins & value_array_t::get_builtins() const {
             const bool reverse = val_reverse->as_bool(); // undefined == false
             const bool attr_is_int = is_val<value_int>(attribute);
             const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
-            const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
-            std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
+            std::vector<value> arr = val->as_array(); // copy
             std::sort(arr.begin(), arr.end(),[&](const value & a, const value & b) {
                 value val_a = a;
                 value val_b = b;
@@ -884,22 +880,23 @@ const func_builtins & value_array_t::get_builtins() const {
                     if (attr_is_int && is_val<value_array>(a) && is_val<value_array>(b)) {
                         val_a = a->at(attr_int);
                         val_b = b->at(attr_int);
-                    } else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(a) && is_val<value_object>(b)) {
-                        val_a = a->at(attr_name);
-                        val_b = b->at(attr_name);
+                    } else if (!attr_is_int && is_val<value_object>(a) && is_val<value_object>(b)) {
+                        val_a = a->at(attribute);
+                        val_b = b->at(attribute);
                     } else {
-                        throw raised_exception("sort: unsupported object attribute comparison");
+                        throw raised_exception("sort: unsupported object attribute comparison between " + a->type() + " and " + b->type());
                     }
                 }
                 return value_compare(val_a, val_b, reverse ? value_compare_op::gt : value_compare_op::lt);
             });
-            return mk_val<value_array>(arr);
+            return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
         }},
         {"reverse", [](const func_args & args) -> value {
             args.ensure_vals<value_array>();
-            std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
+            value val = args.get_pos(0);
+            std::vector<value> arr = val->as_array(); // copy
             std::reverse(arr.begin(), arr.end());
-            return mk_val<value_array>(arr);
+            return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
         }},
         {"unique", [](const func_args &) -> value {
             throw not_implemented_exception("Array unique builtin not implemented");
@@ -930,7 +927,7 @@ const func_builtins & value_object_t::get_builtins() const {
                 default_val = args.get_pos(2);
             }
             const value obj = args.get_pos(0);
-            std::string key = args.get_pos(1)->as_string().str();
+            const value key = args.get_pos(1);
             return obj->at(key, default_val);
         }},
         {"keys", [](const func_args & args) -> value {
@@ -938,7 +935,7 @@ const func_builtins & value_object_t::get_builtins() const {
             const auto & obj = args.get_pos(0)->as_ordered_object();
             auto result = mk_val<value_array>();
             for (const auto & pair : obj) {
-                result->push_back(mk_val<value_string>(pair.first));
+                result->push_back(pair.first);
             }
             return result;
         }},
@@ -956,15 +953,16 @@ const func_builtins & value_object_t::get_builtins() const {
             const auto & obj = args.get_pos(0)->as_ordered_object();
             auto result = mk_val<value_array>();
             for (const auto & pair : obj) {
-                auto item = mk_val<value_array>();
-                item->push_back(mk_val<value_string>(pair.first));
-                item->push_back(pair.second);
+                auto item = mk_val<value_tuple>(pair);
                 result->push_back(std::move(item));
             }
             return result;
         }},
         {"tojson", tojson},
-        {"string", tojson},
+        {"string", [](const func_args & args) -> value {
+            args.ensure_vals<value_object>();
+            return mk_val<value_string>(args.get_pos(0)->as_string());
+        }},
         {"length", [](const func_args & args) -> value {
             args.ensure_vals<value_object>();
             const auto & obj = args.get_pos(0)->as_ordered_object();
@@ -985,11 +983,11 @@ const func_builtins & value_object_t::get_builtins() const {
             const bool reverse = val_reverse->as_bool(); // undefined == false
             const bool by_value = is_val<value_string>(val_by) && val_by->as_string().str() == "value" ? true : false;
             auto result = mk_val<value_object>(val_input); // copy
-            std::sort(result->val_obj.ordered.begin(), result->val_obj.ordered.end(), [&](const auto & a, const auto & b) {
+            std::sort(result->val_obj.begin(), result->val_obj.end(), [&](const auto & a, const auto & b) {
                 if (by_value) {
                     return value_compare(a.second, b.second, reverse ? value_compare_op::gt : value_compare_op::lt);
                 } else {
-                    return reverse ? a.first > b.first : a.first < b.first;
+                    return value_compare(a.first, b.first, reverse ? value_compare_op::gt : value_compare_op::lt);
                 }
             });
             return result;
@@ -1134,6 +1132,8 @@ void global_from_json(context & ctx, const nlohmann::ordered_json & json_obj, bo
     }
 }
 
+// recursively convert value to JSON string
+// TODO: avoid circular references
 static void value_to_json_internal(std::ostringstream & oss, const value & val, int curr_lvl, int indent, const std::string_view item_sep, const std::string_view key_sep) {
     auto indent_str = [indent, curr_lvl]() -> std::string {
         return (indent > 0) ? std::string(curr_lvl * indent, ' ') : "";
@@ -1196,7 +1196,8 @@ static void value_to_json_internal(std::ostringstream & oss, const value & val,
             size_t i = 0;
             for (const auto & pair : obj) {
                 oss << indent_str() << (indent > 0 ? std::string(indent, ' ') : "");
-                oss << "\"" << pair.first << "\"" << key_sep;
+                value_to_json_internal(oss, mk_val<value_string>(pair.first->as_string().str()), curr_lvl + 1, indent, item_sep, key_sep);
+                oss << key_sep;
                 value_to_json_internal(oss, pair.second, curr_lvl + 1, indent, item_sep, key_sep);
                 if (i < obj.size() - 1) {
                     oss << item_sep;
@@ -1219,4 +1220,19 @@ std::string value_to_json(const value & val, int indent, const std::string_view
     return oss.str();
 }
 
+// TODO: avoid circular references
+std::string value_to_string_repr(const value & val) {
+    if (is_val<value_string>(val)) {
+        const std::string val_str = val->as_string().str();
+
+        if (val_str.find('\'') != std::string::npos) {
+            return value_to_json(val);
+        } else {
+            return "'" + val_str + "'";
+        }
+    } else {
+        return val->as_repr();
+    }
+}
+
 } // namespace jinja
index ccb05c6fd413befbc3eb3d6a4768e06e324299b4..a2f92d2c69da89c145c536bf93448c24e6badaef 100644 (file)
@@ -1,8 +1,10 @@
 #pragma once
 
 #include "string.h"
+#include "utils.h"
 
 #include <algorithm>
+#include <cmath>
 #include <cstdint>
 #include <functional>
 #include <map>
@@ -93,7 +95,8 @@ void global_from_json(context & ctx, const T_JSON & json_obj, bool mark_input);
 
 struct func_args; // function argument values
 
-using func_handler = std::function<value(const func_args &)>;
+using func_hptr = value(const func_args &);
+using func_handler = std::function<func_hptr>;
 using func_builtins = std::map<std::string, func_handler>;
 
 enum value_compare_op { eq, ge, gt, lt, ne };
@@ -103,28 +106,9 @@ struct value_t {
     int64_t val_int;
     double val_flt;
     string val_str;
-    bool val_bool;
 
     std::vector<value> val_arr;
-
-    struct map {
-        // once set to true, all keys must be numeric
-        // caveat: we only allow either all numeric keys or all non-numeric keys
-        // for now, this only applied to for_statement in case of iterating over object keys/items
-        bool is_key_numeric = false;
-        std::map<std::string, value> unordered;
-        std::vector<std::pair<std::string, value>> ordered;
-        void insert(const std::string & key, const value & val) {
-            if (unordered.find(key) != unordered.end()) {
-                // if key exists, remove from ordered list
-                ordered.erase(std::remove_if(ordered.begin(), ordered.end(),
-                    [&](const std::pair<std::string, value> & p) { return p.first == key; }),
-                    ordered.end());
-            }
-            unordered[key] = val;
-            ordered.push_back({key, val});
-        }
-    } val_obj;
+    std::vector<std::pair<value, value>> val_obj;
 
     func_handler val_func;
 
@@ -139,6 +123,7 @@ struct value_t {
     value_t(const value_t &) = default;
     virtual ~value_t() = default;
 
+    // Note: only for debugging and error reporting purposes
     virtual std::string type() const { return ""; }
 
     virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); }
@@ -146,7 +131,7 @@ struct value_t {
     virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
     virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
     virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
-    virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
+    virtual const std::vector<std::pair<value, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
     virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
     virtual bool is_none() const { return false; }
     virtual bool is_undefined() const { return false; }
@@ -154,43 +139,66 @@ struct value_t {
         throw std::runtime_error("No builtins available for type " + type());
     }
 
-    virtual bool has_key(const std::string & key) {
-        return val_obj.unordered.find(key) != val_obj.unordered.end();
-    }
-    virtual value & at(const std::string & key, value & default_val) {
-        auto it = val_obj.unordered.find(key);
-        if (it == val_obj.unordered.end()) {
-            return default_val;
-        }
-        return val_obj.unordered.at(key);
-    }
-    virtual value & at(const std::string & key) {
-        auto it = val_obj.unordered.find(key);
-        if (it == val_obj.unordered.end()) {
-            throw std::runtime_error("Key '" + key + "' not found in value of type " + type());
-        }
-        return val_obj.unordered.at(key);
+    virtual bool has_key(const value &) { throw std::runtime_error(type() + " is not an object value"); }
+    virtual void insert(const value & /* key */, const value & /* val */) { throw std::runtime_error(type() + " is not an object value"); }
+    virtual value & at(const value & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
+    virtual value & at(const value & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
+    virtual value & at(const std::string & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
+    virtual value & at(const std::string & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
+    virtual value & at(int64_t /* idx */, value & /* default_val */) { throw std::runtime_error(type() + " is not an array value"); }
+    virtual value & at(int64_t /* idx */) { throw std::runtime_error(type() + " is not an array value"); }
+
+    virtual bool is_numeric() const { return false; }
+    virtual bool is_hashable() const { return false; }
+    virtual bool is_immutable() const { return true; }
+    virtual hasher unique_hash() const noexcept = 0;
+    // TODO: C++20 <=> operator
+    // NOTE: We are treating == as equivalent (for normal comparisons) and != as strict nonequal (for strict (is) comparisons)
+    virtual bool operator==(const value_t & other) const { return equivalent(other); }
+    virtual bool operator!=(const value_t & other) const { return nonequal(other); }
+
+    // Note: only for debugging purposes
+    virtual std::string as_repr() const { return as_string().str(); }
+
+protected:
+    virtual bool equivalent(const value_t &) const = 0;
+    virtual bool nonequal(const value_t & other) const { return !equivalent(other); }
+};
+
+//
+// utils
+//
+
+const func_builtins & global_builtins();
+
+std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": ");
+
+// Note: only used for debugging purposes
+std::string value_to_string_repr(const value & val);
+
+struct not_implemented_exception : public std::runtime_error {
+    not_implemented_exception(const std::string & msg) : std::runtime_error("NotImplemented: " + msg) {}
+};
+
+struct value_hasher {
+    size_t operator()(const value & val) const noexcept {
+        return val->unique_hash().digest();
     }
-    virtual value & at(int64_t index, value & default_val) {
-        if (index < 0) {
-            index += val_arr.size();
-        }
-        if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
-            return default_val;
-        }
-        return val_arr[index];
+};
+
+struct value_equivalence {
+    bool operator()(const value & lhs, const value & rhs) const {
+        return *lhs == *rhs;
     }
-    virtual value & at(int64_t index) {
-        if (index < 0) {
-            index += val_arr.size();
-        }
-        if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
-            throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
-        }
-        return val_arr[index];
+    bool operator()(const std::pair<value, value> & lhs, const std::pair<value, value> & rhs) const {
+        return *(lhs.first) == *(rhs.first) && *(lhs.second) == *(rhs.second);
     }
+};
 
-    virtual std::string as_repr() const { return as_string().str(); }
+struct value_equality {
+    bool operator()(const value & lhs, const value & rhs) const {
+        return !(*lhs != *rhs);
+    }
 };
 
 //
@@ -198,24 +206,49 @@ struct value_t {
 //
 
 struct value_int_t : public value_t {
-    value_int_t(int64_t v) { val_int = v; }
+    value_int_t(int64_t v) {
+        val_int = v;
+        val_flt = static_cast<double>(v);
+        if (static_cast<int64_t>(val_flt) != v) {
+            val_flt = v < 0 ? -INFINITY : INFINITY;
+        }
+    }
     virtual std::string type() const override { return "Integer"; }
     virtual int64_t as_int() const override { return val_int; }
-    virtual double as_float() const override { return static_cast<double>(val_int); }
+    virtual double as_float() const override { return val_flt; }
     virtual string as_string() const override { return std::to_string(val_int); }
     virtual bool as_bool() const override {
         return val_int != 0;
     }
     virtual const func_builtins & get_builtins() const override;
+    virtual bool is_numeric() const override { return true; }
+    virtual bool is_hashable() const override { return true; }
+    virtual hasher unique_hash() const noexcept override {
+        return hasher(typeid(*this))
+            .update(&val_int, sizeof(val_int))
+            .update(&val_flt, sizeof(val_flt));
+    }
+protected:
+    virtual bool equivalent(const value_t & other) const override {
+        return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
+    }
+    virtual bool nonequal(const value_t & other) const override {
+        return !(typeid(*this) == typeid(other) && val_int == other.val_int);
+    }
 };
 using value_int = std::shared_ptr<value_int_t>;
 
 
 struct value_float_t : public value_t {
-    value_float_t(double v) { val_flt = v; }
+    value val;
+    value_float_t(double v) {
+        val_flt = v;
+        val_int = std::isfinite(v) ? static_cast<int64_t>(v) : 0;
+        val = mk_val<value_int>(val_int);
+    }
     virtual std::string type() const override { return "Float"; }
     virtual double as_float() const override { return val_flt; }
-    virtual int64_t as_int() const override { return static_cast<int64_t>(val_flt); }
+    virtual int64_t as_int() const override { return val_int; }
     virtual string as_string() const override {
         std::string out = std::to_string(val_flt);
         out.erase(out.find_last_not_of('0') + 1, std::string::npos); // remove trailing zeros
@@ -226,6 +259,24 @@ struct value_float_t : public value_t {
         return val_flt != 0.0;
     }
     virtual const func_builtins & get_builtins() const override;
+    virtual bool is_numeric() const override { return true; }
+    virtual bool is_hashable() const override { return true; }
+    virtual hasher unique_hash() const noexcept override {
+        if (static_cast<double>(val_int) == val_flt) {
+            return val->unique_hash();
+        } else {
+            return hasher(typeid(*this))
+                .update(&val_int, sizeof(val_int))
+                .update(&val_flt, sizeof(val_flt));
+        }
+    }
+protected:
+    virtual bool equivalent(const value_t & other) const override {
+        return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
+    }
+    virtual bool nonequal(const value_t & other) const override {
+        return !(typeid(*this) == typeid(other) && val_flt == other.val_flt);
+    }
 };
 using value_float = std::shared_ptr<value_float_t>;
 
@@ -247,19 +298,49 @@ struct value_string_t : public value_t {
         return val_str.length() > 0;
     }
     virtual const func_builtins & get_builtins() const override;
+    virtual bool is_hashable() const override { return true; }
+    virtual hasher unique_hash() const noexcept override {
+        const auto type_hash = typeid(*this).hash_code();
+        auto hash = hasher();
+        hash.update(&type_hash, sizeof(type_hash));
+        val_str.hash_update(hash);
+        return hash;
+    }
     void mark_input() {
         val_str.mark_input();
     }
+protected:
+    virtual bool equivalent(const value_t & other) const override {
+        return typeid(*this) == typeid(other) && val_str.str() == other.val_str.str();
+    }
 };
 using value_string = std::shared_ptr<value_string_t>;
 
 
 struct value_bool_t : public value_t {
-    value_bool_t(bool v) { val_bool = v; }
+    value val;
+    value_bool_t(bool v) {
+        val_int = static_cast<int64_t>(v);
+        val_flt = static_cast<double>(v);
+        val = mk_val<value_int>(val_int);
+    }
     virtual std::string type() const override { return "Boolean"; }
-    virtual bool as_bool() const override { return val_bool; }
-    virtual string as_string() const override { return std::string(val_bool ? "True" : "False"); }
+    virtual int64_t as_int() const override { return val_int; }
+    virtual bool as_bool() const override { return val_int; }
+    virtual string as_string() const override { return std::string(val_int ? "True" : "False"); }
     virtual const func_builtins & get_builtins() const override;
+    virtual bool is_numeric() const override { return true; }
+    virtual bool is_hashable() const override { return true; }
+    virtual hasher unique_hash() const noexcept override {
+        return val->unique_hash();
+    }
+protected:
+    virtual bool equivalent(const value_t & other) const override {
+        return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
+    }
+    virtual bool nonequal(const value_t & other) const override {
+        return !(typeid(*this) == typeid(other) && val_int == other.val_int);
+    }
 };
 using value_bool = std::shared_ptr<value_bool_t>;
 
@@ -269,13 +350,34 @@ struct value_array_t : public value_t {
     value_array_t(value & v) {
         val_arr = v->val_arr;
     }
+    value_array_t(std::vector<value> && arr) {
+        val_arr = arr;
+    }
     value_array_t(const std::vector<value> & arr) {
         val_arr = arr;
     }
-    void reverse() { std::reverse(val_arr.begin(), val_arr.end()); }
-    void push_back(const value & val) { val_arr.push_back(val); }
-    void push_back(value && val) { val_arr.push_back(std::move(val)); }
+    void reverse() {
+        if (is_immutable()) {
+            throw std::runtime_error("Attempting to modify immutable type");
+        }
+        std::reverse(val_arr.begin(), val_arr.end());
+    }
+    void push_back(const value & val) {
+        if (is_immutable()) {
+            throw std::runtime_error("Attempting to modify immutable type");
+        }
+        val_arr.push_back(val);
+    }
+    void push_back(value && val) {
+        if (is_immutable()) {
+            throw std::runtime_error("Attempting to modify immutable type");
+        }
+        val_arr.push_back(std::move(val));
+    }
     value pop_at(int64_t index) {
+        if (is_immutable()) {
+            throw std::runtime_error("Attempting to modify immutable type");
+        }
         if (index < 0) {
             index = static_cast<int64_t>(val_arr.size()) + index;
         }
@@ -287,64 +389,225 @@ struct value_array_t : public value_t {
         return val;
     }
     virtual std::string type() const override { return "Array"; }
+    virtual bool is_immutable() const override { return false; }
     virtual const std::vector<value> & as_array() const override { return val_arr; }
     virtual string as_string() const override {
+        const bool immutable = is_immutable();
         std::ostringstream ss;
-        ss << "[";
+        ss << (immutable ? "(" : "[");
         for (size_t i = 0; i < val_arr.size(); i++) {
             if (i > 0) ss << ", ";
-            ss << val_arr.at(i)->as_repr();
+            value val = val_arr.at(i);
+            ss << value_to_string_repr(val);
         }
-        ss << "]";
+        if (immutable && val_arr.size() == 1) {
+            ss << ",";
+        }
+        ss << (immutable ? ")" : "]");
         return ss.str();
     }
     virtual bool as_bool() const override {
         return !val_arr.empty();
     }
+    virtual value & at(int64_t index, value & default_val) override {
+        if (index < 0) {
+            index += val_arr.size();
+        }
+        if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
+            return default_val;
+        }
+        return val_arr[index];
+    }
+    virtual value & at(int64_t index) override {
+        if (index < 0) {
+            index += val_arr.size();
+        }
+        if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
+            throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
+        }
+        return val_arr[index];
+    }
     virtual const func_builtins & get_builtins() const override;
+    virtual bool is_hashable() const override {
+        if (std::all_of(val_arr.begin(), val_arr.end(), [&](auto & val) -> bool {
+            return val->is_immutable() && val->is_hashable();
+        })) {
+            return true;
+        }
+        return false;
+    }
+    virtual hasher unique_hash() const noexcept override {
+        auto hash = hasher(typeid(*this));
+        for (const auto & val : val_arr) {
+            // must use digest to prevent problems from "concatenation" property of hasher
+            // for ex. hash of [ "ab", "c" ] should be different from [ "a", "bc" ]
+            const size_t val_hash = val->unique_hash().digest();
+            hash.update(&val_hash, sizeof(size_t));
+        }
+        return hash;
+    }
+protected:
+    virtual bool equivalent(const value_t & other) const override {
+        return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), value_equivalence());
+    }
 };
 using value_array = std::shared_ptr<value_array_t>;
 
 
+struct value_tuple_t : public value_array_t {
+    value_tuple_t(value & v) {
+        val_arr = v->val_arr;
+    }
+    value_tuple_t(std::vector<value> && arr) {
+        val_arr = arr;
+    }
+    value_tuple_t(const std::vector<value> & arr) {
+        val_arr = arr;
+    }
+    value_tuple_t(const std::pair<value, value> & pair) {
+        val_arr.push_back(pair.first);
+        val_arr.push_back(pair.second);
+    }
+    virtual std::string type() const override { return "Tuple"; }
+    virtual bool is_immutable() const override { return true; }
+};
+using value_tuple = std::shared_ptr<value_tuple_t>;
+
+
 struct value_object_t : public value_t {
+    std::unordered_map<value, value, value_hasher, value_equivalence> unordered;
     bool has_builtins = true; // context and loop objects do not have builtins
     value_object_t() = default;
     value_object_t(value & v) {
         val_obj = v->val_obj;
+        for (const auto & pair : val_obj) {
+            unordered[pair.first] = pair.second;
+        }
     }
-    value_object_t(const std::map<std::string, value> & obj) {
+    value_object_t(const std::map<value, value> & obj) {
         for (const auto & pair : obj) {
-            val_obj.insert(pair.first, pair.second);
+            insert(pair.first, pair.second);
         }
     }
-    value_object_t(const std::vector<std::pair<std::string, value>> & obj) {
+    value_object_t(const std::vector<std::pair<value, value>> & obj) {
         for (const auto & pair : obj) {
-            val_obj.insert(pair.first, pair.second);
+            insert(pair.first, pair.second);
         }
     }
     void insert(const std::string & key, const value & val) {
-        val_obj.insert(key, val);
+        insert(mk_val<value_string>(key), val);
     }
     virtual std::string type() const override { return "Object"; }
-    virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const override { return val_obj.ordered; }
+    virtual bool is_immutable() const override { return false; }
+    virtual const std::vector<std::pair<value, value>> & as_ordered_object() const override { return val_obj; }
+    virtual string as_string() const override {
+        std::ostringstream ss;
+        ss << "{";
+        for (size_t i = 0; i < val_obj.size(); i++) {
+            if (i > 0) ss << ", ";
+            auto & [key, val] = val_obj.at(i);
+            ss << value_to_string_repr(key) << ": " << value_to_string_repr(val);
+        }
+        ss << "}";
+        return ss.str();
+    }
     virtual bool as_bool() const override {
-        return !val_obj.unordered.empty();
+        return !unordered.empty();
+    }
+    virtual bool has_key(const value & key) override {
+        if (!key->is_immutable() || !key->is_hashable()) {
+            throw std::runtime_error("Object key of unhashable type: " + key->type());
+        }
+        return unordered.find(key) != unordered.end();
+    }
+    virtual void insert(const value & key, const value & val) override {
+        bool replaced = false;
+        if (is_immutable()) {
+            throw std::runtime_error("Attempting to modify immutable type");
+        }
+        if (has_key(key)) {
+            // if key exists, replace value in ordered list instead of appending
+            for (auto & pair : val_obj) {
+                if (*(pair.first) == *key) {
+                    pair.second = val;
+                    replaced = true;
+                    break;
+                }
+            }
+        }
+        unordered[key] = val;
+        if (!replaced) {
+            val_obj.push_back({key, val});
+        }
+    }
+    virtual value & at(const value & key, value & default_val) override {
+        if (!has_key(key)) {
+            return default_val;
+        }
+        return unordered.at(key);
+    }
+    virtual value & at(const value & key) override {
+        if (!has_key(key)) {
+            throw std::runtime_error("Key '" + key->as_string().str() + "' not found in value of type " + type());
+        }
+        return unordered.at(key);
+    }
+    virtual value & at(const std::string & key, value & default_val) override {
+        value key_val = mk_val<value_string>(key);
+        return at(key_val, default_val);
+    }
+    virtual value & at(const std::string & key) override {
+        value key_val = mk_val<value_string>(key);
+        return at(key_val);
     }
     virtual const func_builtins & get_builtins() const override;
+    virtual bool is_hashable() const override {
+        if (std::all_of(val_obj.begin(), val_obj.end(), [&](auto & pair) -> bool {
+            const auto & val = pair.second;
+            return val->is_immutable() && val->is_hashable();
+        })) {
+            return true;
+        }
+        return false;
+    }
+    virtual hasher unique_hash() const noexcept override {
+        auto hash = hasher(typeid(*this));
+        for (const auto & [key, val] : val_obj) {
+            // must use digest to prevent problems from "concatenation" property of hasher
+            // for ex. hash of key="ab", value="c" should be different from key="a", value="bc"
+            const size_t key_hash = key->unique_hash().digest();
+            const size_t val_hash = val->unique_hash().digest();
+            hash.update(&key_hash, sizeof(key_hash));
+            hash.update(&val_hash, sizeof(val_hash));
+        }
+        return hash;
+    }
+protected:
+    virtual bool equivalent(const value_t & other) const override {
+        return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), value_equivalence());
+    }
 };
 using value_object = std::shared_ptr<value_object_t>;
 
 //
-// null and undefined types
+// none and undefined types
 //
 
 struct value_none_t : public value_t {
     virtual std::string type() const override { return "None"; }
     virtual bool is_none() const override { return true; }
     virtual bool as_bool() const override { return false; }
-    virtual string as_string() const override { return string("None"); }
+    virtual string as_string() const override { return string(type()); }
     virtual std::string as_repr() const override { return type(); }
     virtual const func_builtins & get_builtins() const override;
+    virtual bool is_hashable() const override { return true; }
+    virtual hasher unique_hash() const noexcept override {
+        return hasher(typeid(*this));
+    }
+protected:
+    virtual bool equivalent(const value_t & other) const override {
+        return typeid(*this) == typeid(other);
+    }
 };
 using value_none = std::shared_ptr<value_none_t>;
 
@@ -356,6 +619,13 @@ struct value_undefined_t : public value_t {
     virtual bool as_bool() const override { return false; }
     virtual std::string as_repr() const override { return type(); }
     virtual const func_builtins & get_builtins() const override;
+    virtual hasher unique_hash() const noexcept override {
+        return hasher(typeid(*this));
+    }
+protected:
+    virtual bool equivalent(const value_t & other) const override {
+        return is_undefined() == other.is_undefined();
+    }
 };
 using value_undefined = std::shared_ptr<value_undefined_t>;
 
@@ -436,7 +706,23 @@ struct value_func_t : public value_t {
         return val_func(new_args);
     }
     virtual std::string type() const override { return "Function"; }
-    virtual std::string as_repr() const override { return type(); }
+    virtual std::string as_repr() const override { return type() + "<" + name + ">(" + (arg0 ? arg0->as_repr() : "") + ")"; }
+    virtual bool is_hashable() const override { return false; }
+    virtual hasher unique_hash() const noexcept override {
+        // Note: this is unused for now, we don't support function as object keys
+        // use function pointer as unique identifier
+        const auto target = val_func.target<func_hptr>();
+        return hasher(typeid(*this)).update(&target, sizeof(target));
+    }
+protected:
+    virtual bool equivalent(const value_t & other) const override {
+        // Note: this is unused for now, we don't support function as object keys
+        // compare function pointers
+        // (val_func == other.val_func does not work as std::function::operator== is only used for nullptr check)
+        const auto target_this  = this->val_func.target<func_hptr>();
+        const auto target_other = other.val_func.target<func_hptr>();
+        return typeid(*this) == typeid(other) && target_this == target_other;
+    }
 };
 using value_func = std::shared_ptr<value_func_t>;
 
@@ -447,18 +733,21 @@ struct value_kwarg_t : public value_t {
     value_kwarg_t(const std::string & k, const value & v) : key(k), val(v) {}
     virtual std::string type() const override { return "KwArg"; }
     virtual std::string as_repr() const override { return type(); }
+    virtual bool is_hashable() const override { return true; }
+    virtual hasher unique_hash() const noexcept override {
+        const auto type_hash = typeid(*this).hash_code();
+        auto hash = val->unique_hash();
+        hash.update(&type_hash, sizeof(type_hash))
+            .update(key.data(), key.size());
+        return hash;
+    }
+protected:
+    virtual bool equivalent(const value_t & other) const override {
+        const value_kwarg_t & other_val = static_cast<const value_kwarg_t &>(other);
+        return typeid(*this) == typeid(other) && key == other_val.key && val == other_val.val;
+    }
 };
 using value_kwarg = std::shared_ptr<value_kwarg_t>;
 
 
-// utils
-
-const func_builtins & global_builtins();
-std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": ");
-
-struct not_implemented_exception : public std::runtime_error {
-    not_implemented_exception(const std::string & msg) : std::runtime_error("NotImplemented: " + msg) {}
-};
-
-
 } // namespace jinja
index e142900723784a35a9037c42918dbadaa7cd5d04..d2a1437ca44a13aa43746a96515d9a72b05b9346 100644 (file)
@@ -481,7 +481,7 @@ int main_automated_tests(void) {
             /* .name= */ "Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)",
             /* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n    {%- set system_message = messages[0][\"content\"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n    {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n        {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n            {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n        {%- endif %}\n        {%- set ns.index = ns.index + 1 %}\n    {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if message[\"role\"] == \"user\" %}\n        {%- if tools is not none and (message == user_messages[-1]) %}\n            {{- \"[AVAILABLE_TOOLS] [\" }}\n            {%- for tool in tools %}\n                {%- set tool = tool.function %}\n                {{- '{\"type\": \"function\", \"function\": {' }}\n                {%- for key, val in tool.items() if key != \"return\" %}\n                    {%- if val is string %}\n                        {{- '\"' + key + '\": \"' + val + '\"' }}\n                    {%- else %}\n                        {{- '\"' + key + '\": ' + val|tojson }}\n                    {%- endif %}\n                    {%- if not loop.last %}\n                        {{- \", \" }}\n                    {%- endif %}\n                {%- endfor %}\n                {{- \"}}\" }}\n                {%- if not loop.last %}\n                    {{- \", \" }}\n                {%- else %}\n                    {{- \"]\" }}\n                {%- endif %}\n            {%- endfor %}\n            {{- \"[/AVAILABLE_TOOLS]\" }}\n            {%- endif %}\n        {%- if loop.last and system_message is defined %}\n            {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n        {%- else %}\n            {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n        {%- endif %}\n    {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n        {{- \"[TOOL_CALLS] [\" }}\n        {%- for tool_call in message.tool_calls %}\n            {%- set out = tool_call.function|tojson %}\n            {{- out[:-1] }}\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n                {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n            {%- endif %}\n            {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n            {%- if not loop.last %}\n                {{- \", \" }}\n            {%- else %}\n                {{- \"]\" + eos_token }}\n            {%- endif %}\n        {%- endfor %}\n    {%- elif message[\"role\"] == \"assistant\" %}\n        {{- \" \" + message[\"content\"]|trim + eos_token}}\n    {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n        {%- if message.content is defined and message.content.content is defined %}\n            {%- set content = message.content.content %}\n        {%- else %}\n            {%- set content = message.content %}\n        {%- endif %}\n        {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n        {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n            {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n        {%- endif %}\n        {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n    {%- else %}\n        {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n    {%- endif %}\n{%- endfor %}\n",
             /* .expected_output= */       "[INST] You are a helpful assistant\n\nHello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] Another question[/INST]",
-            /* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] You are a helpful assistant\n\nAnother question[/INST]",
+            /* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[AVAILABLE_TOOLS] [[/AVAILABLE_TOOLS][INST] You are a helpful assistant\n\nAnother question[/INST]",
             /* .bos_token= */ "",
             /* .eos_token= */ "</s>",
         },
@@ -489,7 +489,7 @@ int main_automated_tests(void) {
             /* .name= */ "Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)",
             /* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n    {%- set system_message = messages[0][\"content\"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n    {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n        {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n            {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n        {%- endif %}\n        {%- set ns.index = ns.index + 1 %}\n    {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if message[\"role\"] == \"user\" %}\n        {%- if tools is not none and (message == user_messages[-1]) %}\n            {{- \"[AVAILABLE_TOOLS][\" }}\n            {%- for tool in tools %}\n                {%- set tool = tool.function %}\n                {{- '{\"type\": \"function\", \"function\": {' }}\n                {%- for key, val in tool.items() if key != \"return\" %}\n                    {%- if val is string %}\n                        {{- '\"' + key + '\": \"' + val + '\"' }}\n                    {%- else %}\n                        {{- '\"' + key + '\": ' + val|tojson }}\n                    {%- endif %}\n                    {%- if not loop.last %}\n                        {{- \", \" }}\n                    {%- endif %}\n                {%- endfor %}\n                {{- \"}}\" }}\n                {%- if not loop.last %}\n                    {{- \", \" }}\n                {%- else %}\n                    {{- \"]\" }}\n                {%- endif %}\n            {%- endfor %}\n            {{- \"[/AVAILABLE_TOOLS]\" }}\n            {%- endif %}\n        {%- if loop.last and system_message is defined %}\n            {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n        {%- else %}\n            {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n        {%- endif %}\n    {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n        {{- \"[TOOL_CALLS][\" }}\n        {%- for tool_call in message.tool_calls %}\n            {%- set out = tool_call.function|tojson %}\n            {{- out[:-1] }}\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n                {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n            {%- endif %}\n            {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n            {%- if not loop.last %}\n                {{- \", \" }}\n            {%- else %}\n                {{- \"]\" + eos_token }}\n            {%- endif %}\n        {%- endfor %}\n    {%- elif message[\"role\"] == \"assistant\" %}\n        {{- message[\"content\"] + eos_token}}\n    {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n        {%- if message.content is defined and message.content.content is defined %}\n            {%- set content = message.content.content %}\n        {%- else %}\n            {%- set content = message.content %}\n        {%- endif %}\n        {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n        {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n            {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n        {%- endif %}\n        {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n    {%- else %}\n        {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n    {%- endif %}\n{%- endfor %}\n",
             /* .expected_output= */       "[INST]You are a helpful assistant\n\nHello[/INST]Hi there</s>[INST]Who are you[/INST]   I am an assistant   </s>[INST]Another question[/INST]",
-            /* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there</s>[INST]Who are you[/INST]   I am an assistant   </s>[INST]You are a helpful assistant\n\nAnother question[/INST]",
+            /* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there</s>[INST]Who are you[/INST]   I am an assistant   </s>[AVAILABLE_TOOLS][[/AVAILABLE_TOOLS][INST]You are a helpful assistant\n\nAnother question[/INST]",
             /* .bos_token= */ "",
             /* .eos_token= */ "</s>",
         },
index 54d3a0923bdf7bdab85ca7c52d33585147348915..7c6eeb311ce49881eeba839a8dee0191f99e4ad7 100644 (file)
@@ -9,6 +9,7 @@
 #include "jinja/runtime.h"
 #include "jinja/parser.h"
 #include "jinja/lexer.h"
+#include "jinja/utils.h"
 
 #include "testing.h"
 
@@ -30,6 +31,7 @@ static void test_tests(testing & t);
 static void test_string_methods(testing & t);
 static void test_array_methods(testing & t);
 static void test_object_methods(testing & t);
+static void test_hasher(testing & t);
 static void test_fuzzing(testing & t);
 
 static bool g_python_mode = false;
@@ -67,6 +69,7 @@ int main(int argc, char *argv[]) {
     t.test("array methods", test_array_methods);
     t.test("object methods", test_object_methods);
     if (!g_python_mode) {
+        t.test("hasher", test_hasher);
         t.test("fuzzing", test_fuzzing);
     }
 
@@ -156,6 +159,18 @@ static void test_conditionals(testing & t) {
         "big"
     );
 
+    test_template(t, "object comparison",
+        "{% if {0: 1, none: 2, 1.0: 3, '0': 4, true: 5} == {false: 1, none: 2, 1: 5, '0': 4} %}equal{% endif %}",
+        json::object(),
+        "equal"
+    );
+
+    test_template(t, "array comparison",
+        "{% if [0, 1.0, false] == [false, 1, 0.0] %}equal{% endif %}",
+        json::object(),
+        "equal"
+    );
+
     test_template(t, "logical and",
         "{% if a and b %}both{% endif %}",
         {{"a", true}, {"b", true}},
@@ -358,6 +373,30 @@ static void test_expressions(testing & t) {
         "b"
     );
 
+    test_template(t, "array negative access",
+        "{{ items[-1] }}",
+        {{"items", json::array({"a", "b", "c"})}},
+        "c"
+    );
+
+    test_template(t, "array slice",
+        "{{ items[1:-1]|string }}",
+        {{"items", json::array({"a", "b", "c"})}},
+        "['b']"
+    );
+
+    test_template(t, "array slice step",
+        "{{ items[::2]|string }}",
+        {{"items", json::array({"a", "b", "c"})}},
+        "['a', 'c']"
+    );
+
+    test_template(t, "tuple slice",
+        "{{ ('a', 'b', 'c')[::-1]|string }}",
+        json::object(),
+        "('c', 'b', 'a')"
+    );
+
     test_template(t, "arithmetic",
         "{{ (a + b) * c }}",
         {{"a", 2}, {"b", 3}, {"c", 4}},
@@ -401,6 +440,36 @@ static void test_set_statement(testing & t) {
         json::object(),
         "1"
     );
+
+    test_template(t, "set dict with mixed type keys",
+        "{% set d = {0: 1, none: 2, 1.0: 3, '0': 4, (0, 0): 5, false: 6, 1: 7} %}{{ d[(0, 0)] + d[0] + d[none] + d['0'] + d[false] + d[1.0] + d[1] }}",
+        json::object(),
+        "37"
+    );
+
+    test_template(t, "print dict with mixed type keys",
+        "{% set d = {0: 1, none: 2, 1.0: 3, '0': 4, (0, 0): 5, true: 6} %}{{ d|string }}",
+        json::object(),
+        "{0: 1, None: 2, 1.0: 6, '0': 4, (0, 0): 5}"
+    );
+
+    test_template(t, "print array with mixed types",
+        "{% set d = [0, none, 1.0, '0', true, (0, 0)] %}{{ d|string }}",
+        json::object(),
+        "[0, None, 1.0, '0', True, (0, 0)]"
+    );
+
+    test_template(t, "object member assignment with mixed key types",
+        "{% set d = namespace() %}{% set d.a = 123 %}{{ d['a'] == 123 }}",
+        json::object(),
+        "True"
+    );
+
+    test_template(t, "tuple unpacking",
+        "{% set t = (1, 2, 3) %}{% set a, b, c = t %}{{ a + b + c }}",
+        json::object(),
+        "6"
+    );
 }
 
 static void test_filters(testing & t) {
@@ -1312,6 +1381,154 @@ static void test_object_methods(testing & t) {
         {{"obj", {{"a", "b"}}}},
         "True True"
     );
+
+    test_template(t, "expression as object key",
+        "{% set d = {'ab': 123} %}{{ d['a' + 'b'] == 123 }}",
+        json::object(),
+        "True"
+    );
+
+    test_template(t, "numeric as object key (template: Seed-OSS)",
+        "{% set d = {1: 'a', 2: 'b'} %}{{ d[1] == 'a' and d[2] == 'b' }}",
+        json::object(),
+        "True"
+    );
+}
+
+static void test_hasher(testing & t) {
+    static const std::vector<std::pair<size_t, size_t>> chunk_sizes = {
+        {1, 2},
+        {1, 16},
+        {8, 1},
+        {1, 1024},
+        {5, 512},
+        {16, 256},
+        {45, 122},
+        {70, 634},
+    };
+
+    static auto random_bytes = [](size_t length) -> std::string {
+        std::string data;
+        data.resize(length);
+        for (size_t i = 0; i < length; ++i) {
+            data[i] = static_cast<char>(rand() % 256);
+        }
+        return data;
+    };
+
+    t.test("state unchanged with empty input", [](testing & t) {
+        jinja::hasher hasher;
+        hasher.update("some data");
+        size_t initial_state = hasher.digest();
+        hasher.update("", 0);
+        size_t final_state = hasher.digest();
+        t.assert_true("Hasher state should remain unchanged", initial_state == final_state);
+    });
+
+    t.test("different inputs produce different hashes", [](testing & t) {
+        jinja::hasher hasher1;
+        hasher1.update("data one");
+        size_t hash1 = hasher1.digest();
+
+        jinja::hasher hasher2;
+        hasher2.update("data two");
+        size_t hash2 = hasher2.digest();
+
+        t.assert_true("Different inputs should produce different hashes", hash1 != hash2);
+    });
+
+    t.test("same inputs produce same hashes", [](testing & t) {
+        jinja::hasher hasher1;
+        hasher1.update("consistent data");
+        size_t hash1 = hasher1.digest();
+
+        jinja::hasher hasher2;
+        hasher2.update("consistent data");
+        size_t hash2 = hasher2.digest();
+
+        t.assert_true("Same inputs should produce same hashes", hash1 == hash2);
+    });
+
+    t.test("property: update(a ~ b) == update(a).update(b)", [](testing & t) {
+        for (const auto & [size1, size2] : chunk_sizes) {
+            std::string data1 = random_bytes(size1);
+            std::string data2 = random_bytes(size2);
+
+            jinja::hasher hasher1;
+            hasher1.update(data1);
+            hasher1.update(data2);
+            size_t hash1 = hasher1.digest();
+
+            jinja::hasher hasher2;
+            hasher2.update(data1 + data2);
+            size_t hash2 = hasher2.digest();
+
+            t.assert_true(
+                "Hashing in multiple updates should match single update (" + std::to_string(size1) + ", " + std::to_string(size2) + ")",
+                hash1 == hash2);
+        }
+    });
+
+    t.test("property: update(a ~ b) == update(a).update(b) with more update passes", [](testing & t) {
+        static const std::vector<size_t> sizes = {3, 732, 131, 13, 17, 256, 436, 99, 4};
+
+        jinja::hasher hasher1;
+        jinja::hasher hasher2;
+
+        std::string combined_data;
+        for (size_t size : sizes) {
+            std::string data = random_bytes(size);
+            hasher1.update(data);
+            combined_data += data;
+        }
+
+        hasher2.update(combined_data);
+        size_t hash1 = hasher1.digest();
+        size_t hash2 = hasher2.digest();
+        t.assert_true(
+            "Hashing in multiple updates should match single update with many chunks",
+            hash1 == hash2);
+    });
+
+    t.test("property: non associativity of update", [](testing & t) {
+        for (const auto & [size1, size2] : chunk_sizes) {
+            std::string data1 = random_bytes(size1);
+            std::string data2 = random_bytes(size2);
+
+            jinja::hasher hasher1;
+            hasher1.update(data1);
+            hasher1.update(data2);
+            size_t hash1 = hasher1.digest();
+
+            jinja::hasher hasher2;
+            hasher2.update(data2);
+            hasher2.update(data1);
+            size_t hash2 = hasher2.digest();
+
+            t.assert_true(
+                "Hashing order should matter (" + std::to_string(size1) + ", " + std::to_string(size2) + ")",
+                hash1 != hash2);
+        }
+    });
+
+    t.test("property: different lengths produce different hashes (padding block size)", [](testing & t) {
+        std::string random_data = random_bytes(64);
+
+        jinja::hasher hasher1;
+        hasher1.update(random_data);
+        size_t hash1 = hasher1.digest();
+
+        for (int i = 0; i < 16; ++i) {
+            random_data.push_back('A');  // change length
+            jinja::hasher hasher2;
+            hasher2.update(random_data);
+            size_t hash2 = hasher2.digest();
+
+            t.assert_true("Different lengths should produce different hashes (length " + std::to_string(random_data.size()) + ")", hash1 != hash2);
+
+            hash1 = hash2;
+        }
+    });
 }
 
 static void test_template_cpp(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {