]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : allow overriding GGUF metadata when loading model (#4092)
authorKerfuffle <redacted>
Tue, 5 Dec 2023 17:19:18 +0000 (10:19 -0700)
committerGitHub <redacted>
Tue, 5 Dec 2023 17:19:18 +0000 (19:19 +0200)
* feat: Allow overriding GGUF metadata when loading model

* Fix the one time GCC is stricter than clang about something

* Step1

* Refactor... basically everything!

* Nuke obsolete GetArrayLen struct

* simplify std::string specialization

* Various cleanups

Add informational output when overrides are applied

Warn user when an override with the wrong type is specified

* Fix broken logic for parsing bool KV overrides
Fix issue where overrides didn't apply when key missing in GGUF metadata
Resolve merge changes

* llama : rearrange model params

* Update new GET_KEY call

Add note that metadata KV overrides aren't reflected in initial metadata KV info dump

---------

Co-authored-by: cebtenzzre <redacted>
Co-authored-by: Georgi Gerganov <redacted>
common/common.cpp
common/common.h
llama.cpp
llama.h

index 8e6d74d0d704a3cb2fb3b8986c69171171a99d80..4e823c526e2e6b799dde84327eecf64e6978b349 100644 (file)
@@ -690,6 +690,47 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 std::istreambuf_iterator<char>(),
                 std::back_inserter(sparams.grammar)
             );
+        } else if (arg == "--override-kv") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            char * sep = strchr(argv[i], '=');
+            if (sep == nullptr || sep - argv[i] >= 128) {
+                fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]);
+                invalid_param = true;
+                break;
+            }
+            struct llama_model_kv_override kvo;
+            std::strncpy(kvo.key, argv[i], sep - argv[i]);
+            kvo.key[sep - argv[i]] = 0;
+            sep++;
+            if (strncmp(sep, "int:", 4) == 0) {
+                sep += 4;
+                kvo.tag = LLAMA_KV_OVERRIDE_INT;
+                kvo.int_value = std::atol(sep);
+            } else if (strncmp(sep, "float:", 6) == 0) {
+                sep += 6;
+                kvo.tag = LLAMA_KV_OVERRIDE_FLOAT;
+                kvo.float_value = std::atof(sep);
+            } else if (strncmp(sep, "bool:", 5) == 0) {
+                sep += 5;
+                kvo.tag = LLAMA_KV_OVERRIDE_BOOL;
+                if (std::strcmp(sep, "true") == 0) {
+                    kvo.bool_value = true;
+                } else if (std::strcmp(sep, "false") == 0) {
+                    kvo.bool_value = false;
+                } else {
+                    fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
+                    invalid_param = true;
+                    break;
+                }
+            } else {
+                fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
+                invalid_param = true;
+                break;
+            }
+            params.kv_overrides.push_back(kvo);
 #ifndef LOG_DISABLE_LOGS
         // Parse args for logging parameters
         } else if ( log_param_single_parse( argv[i] ) ) {
@@ -733,6 +774,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         }
     }
 
+    if (!params.kv_overrides.empty()) {
+        params.kv_overrides.emplace_back(llama_model_kv_override());
+        params.kv_overrides.back().key[0] = 0;
+    }
+
     return true;
 }
 
@@ -864,6 +910,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        draft model for speculative decoding (default: %s)\n", params.model.c_str());
     printf("  -ld LOGDIR, --logdir LOGDIR\n");
     printf("                        path under which to save YAML logs (no logging if unset)\n");
+    printf("  --override-kv KEY=TYPE:VALUE\n");
+    printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
+    printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
     printf("\n");
 #ifndef LOG_DISABLE_LOGS
     log_print_usage();
@@ -956,6 +1005,12 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
     mparams.tensor_split    = params.tensor_split;
     mparams.use_mmap        = params.use_mmap;
     mparams.use_mlock       = params.use_mlock;
+    if (params.kv_overrides.empty()) {
+        mparams.kv_overrides = NULL;
+    } else {
+        GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
+        mparams.kv_overrides = params.kv_overrides.data();
+    }
 
     return mparams;
 }
index 534f7b1322da2adaa0a88b835e95a385af7065e1..02467938061b2e7d213a413b34ee4d9bf460271f 100644 (file)
@@ -86,6 +86,8 @@ struct gpt_params {
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
     std::string logdir            = "";  // directory in which to save YAML log files
 
+    std::vector<llama_model_kv_override> kv_overrides;
+
     // TODO: avoid tuple, use struct
     std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
     std::string lora_base  = "";                              // base model path for the lora adapter
index fd905ade7a73bccb9299afae8680196b18cd7516..b77020e10d8a5f467604041bbd54e0805ba8478d 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -74,6 +74,7 @@
 #include <set>
 #include <sstream>
 #include <thread>
+#include <type_traits>
 #include <unordered_map>
 
 #if defined(_MSC_VER)
@@ -590,21 +591,6 @@ struct LLM_TN {
 // gguf helpers
 //
 
-#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
-do { \
-    const std::string skey(key); \
-    const int kid = gguf_find_key(ctx, skey.c_str()); \
-    if (kid >= 0) { \
-        enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
-        if (ktype != (type)) { \
-            throw std::runtime_error(format("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype))); \
-        } \
-        (dst) = func(ctx, kid); \
-    } else if (req) { \
-        throw std::runtime_error(format("key not found in model: %s", skey.c_str())); \
-    } \
-} while (0)
-
 static std::map<int8_t, std::string> LLAMA_ROPE_SCALING_TYPES = {
     { LLAMA_ROPE_SCALING_NONE,   "none"   },
     { LLAMA_ROPE_SCALING_LINEAR, "linear" },
@@ -638,7 +624,7 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int
     }
 }
 
-static std::string gguf_kv_to_str(struct gguf_context * ctx_gguf, int i) {
+static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
     const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
 
     switch (type) {
@@ -1797,6 +1783,169 @@ static std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
     return buf;
 }
 
+namespace GGUFMeta {
+    template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int)>
+    struct GKV_Base_Type {
+        static constexpr gguf_type gt = gt_;
+
+        static T getter(const gguf_context * ctx, const int kid) {
+            return gfun(ctx, kid);
+        }
+    };
+
+    template<typename T> struct GKV_Base;
+
+    template<> struct GKV_Base<bool        >: GKV_Base_Type<bool,         GGUF_TYPE_BOOL,    gguf_get_val_bool> {};
+    template<> struct GKV_Base<uint8_t     >: GKV_Base_Type<uint8_t,      GGUF_TYPE_UINT8,   gguf_get_val_u8  > {};
+    template<> struct GKV_Base<uint16_t    >: GKV_Base_Type<uint16_t,     GGUF_TYPE_UINT16,  gguf_get_val_u16 > {};
+    template<> struct GKV_Base<uint32_t    >: GKV_Base_Type<uint32_t,     GGUF_TYPE_UINT32,  gguf_get_val_u32 > {};
+    template<> struct GKV_Base<uint64_t    >: GKV_Base_Type<uint64_t,     GGUF_TYPE_UINT64,  gguf_get_val_u64 > {};
+    template<> struct GKV_Base<int8_t      >: GKV_Base_Type<int8_t,       GGUF_TYPE_INT8,    gguf_get_val_i8  > {};
+    template<> struct GKV_Base<int16_t     >: GKV_Base_Type<int16_t,      GGUF_TYPE_INT16,   gguf_get_val_i16 > {};
+    template<> struct GKV_Base<int32_t     >: GKV_Base_Type<int32_t,      GGUF_TYPE_INT32,   gguf_get_val_i32 > {};
+    template<> struct GKV_Base<int64_t     >: GKV_Base_Type<int64_t,      GGUF_TYPE_INT64,   gguf_get_val_i64 > {};
+    template<> struct GKV_Base<float       >: GKV_Base_Type<float,        GGUF_TYPE_FLOAT32, gguf_get_val_f32 > {};
+    template<> struct GKV_Base<double      >: GKV_Base_Type<double,       GGUF_TYPE_FLOAT64, gguf_get_val_f64 > {};
+    template<> struct GKV_Base<const char *>: GKV_Base_Type<const char *, GGUF_TYPE_STRING,  gguf_get_val_str > {};
+
+    template<> struct GKV_Base<std::string> {
+        static constexpr gguf_type gt = GGUF_TYPE_STRING;
+
+        static std::string getter(const gguf_context * ctx, const int kid) {
+            return gguf_get_val_str(ctx, kid);
+        }
+    };
+
+    struct ArrayInfo{
+        const gguf_type gt;
+        const size_t length;
+        const void * data;
+    };
+
+    template<> struct GKV_Base<ArrayInfo> {
+        public:
+        static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
+        static ArrayInfo getter(const gguf_context *ctx, const int k) {
+            return ArrayInfo {
+                gguf_get_arr_type(ctx, k),
+                size_t(gguf_get_arr_n(ctx, k)),
+                gguf_get_arr_data(ctx, k),
+            };
+        }
+    };
+
+    template<typename T>
+    class GKV: public GKV_Base<T> {
+        GKV() = delete;
+
+        public:
+        static T get_kv(const gguf_context * ctx, const int k) {
+            const enum gguf_type kt = gguf_get_kv_type(ctx, k);
+
+            if (kt != GKV::gt) {
+                throw std::runtime_error(format("key %s has wrong type %s but expected type %s",
+                    gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt)));
+            }
+            return GKV::getter(ctx, k);
+        }
+
+        static const char * override_type_to_str(const llama_model_kv_override_type ty) {
+            switch (ty) {
+                case LLAMA_KV_OVERRIDE_BOOL:  return "bool";
+                case LLAMA_KV_OVERRIDE_INT:   return "int";
+                case LLAMA_KV_OVERRIDE_FLOAT: return "float";
+            }
+            return "unknown";
+        }
+
+        static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override *override) {
+            if (!override) { return false; }
+            if (override->tag == expected_type) {
+                LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ",
+                    __func__, override_type_to_str(override->tag), override->key);
+                switch (override->tag) {
+                    case LLAMA_KV_OVERRIDE_BOOL:  {
+                        printf("%s\n", override->bool_value ? "true" : "false");
+                    } break;
+                    case LLAMA_KV_OVERRIDE_INT:   {
+                        printf("%" PRId64 "\n", override->int_value);
+                    } break;
+                    case LLAMA_KV_OVERRIDE_FLOAT: {
+                        printf("%.6f\n", override->float_value);
+                    } break;
+                    default:
+                        // Shouldn't be possible to end up here, but just in case...
+                        throw std::runtime_error(
+                            format("Unsupported attempt to override %s type for metadata key %s\n",
+                                override_type_to_str(override->tag), override->key));
+                }
+                return true;
+            }
+            LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n",
+                __func__, override->key, override_type_to_str(expected_type), override_type_to_str(override->tag));
+            return false;
+        }
+
+        template<typename OT>
+        static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
+        try_override(OT & target, const struct llama_model_kv_override *override) {
+            if (validate_override(LLAMA_KV_OVERRIDE_BOOL, override)) {
+                target = override->bool_value;
+                return true;
+            }
+            return true;
+        }
+
+        template<typename OT>
+        static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
+        try_override(OT & target, const struct llama_model_kv_override *override) {
+            if (validate_override(LLAMA_KV_OVERRIDE_INT, override)) {
+                target = override->int_value;
+                return true;
+            }
+            return false;
+        }
+
+        template<typename OT>
+        static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
+        try_override(T & target, const struct llama_model_kv_override *override) {
+            if (validate_override(LLAMA_KV_OVERRIDE_FLOAT, override)) {
+                target = override->float_value;
+                return true;
+            }
+            return false;
+        }
+
+        template<typename OT>
+        static typename std::enable_if<std::is_same<OT, std::string>::value, bool>::type
+        try_override(T & target, const struct llama_model_kv_override *override) {
+            (void)target;
+            (void)override;
+            if (!override) { return false; }
+            // Currently, we should never end up here so it would be a bug if we do.
+            throw std::runtime_error(format("Unsupported attempt to override string type for metadata key %s\n",
+                override ? override->key : "NULL"));
+        }
+
+        static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override *override = nullptr) {
+            if (try_override<T>(target, override)) {
+                return true;
+            }
+            if (k < 0) { return false; }
+            target = get_kv(ctx, k);
+            return true;
+        }
+
+        static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override *override = nullptr) {
+            return set(ctx, gguf_find_key(ctx, key), target, override);
+        }
+
+        static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override *override = nullptr) {
+            return set(ctx, key.c_str(), target, override);
+        }
+    };
+}
+
 struct llama_model_loader {
     int n_kv      = 0;
     int n_tensors = 0;
@@ -1812,21 +1961,34 @@ struct llama_model_loader {
     llama_fver  fver;
 
     std::unique_ptr<llama_mmap> mapping;
+    std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
 
     struct gguf_context * ctx_gguf = NULL;
     struct ggml_context * ctx_meta = NULL;
 
-    llama_model_loader(const std::string & fname, bool use_mmap) : file(fname.c_str(), "rb") {
+    std::string arch_name;
+    LLM_KV      llm_kv    = LLM_KV(LLM_ARCH_UNKNOWN);
+
+    llama_model_loader(const std::string & fname, bool use_mmap, const struct llama_model_kv_override * param_overrides_p) : file(fname.c_str(), "rb") {
         struct gguf_init_params params = {
             /*.no_alloc = */ true,
             /*.ctx      = */ &ctx_meta,
         };
 
+        if (param_overrides_p != nullptr) {
+            for (const struct llama_model_kv_override *p = param_overrides_p; p->key[0] != 0; p++) {
+                kv_overrides.insert({std::string(p->key), *p});
+            }
+        }
+
         ctx_gguf = gguf_init_from_file(fname.c_str(), params);
         if (!ctx_gguf) {
             throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str()));
         }
 
+        get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
+        llm_kv = LLM_KV(llm_arch_from_string(arch_name));
+
         n_kv      = gguf_get_n_kv(ctx_gguf);
         n_tensors = gguf_get_n_tensors(ctx_gguf);
 
@@ -1894,6 +2056,7 @@ struct llama_model_loader {
                 }
             }
 
+            LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__);
             for (int i = 0; i < n_kv; i++) {
                 const char * name           = gguf_get_key(ctx_gguf, i);
                 const enum gguf_type type   = gguf_get_kv_type(ctx_gguf, i);
@@ -1939,19 +2102,59 @@ struct llama_model_loader {
         }
     }
 
-    std::string get_arch_name() const {
-        const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
+    template<typename T>
+    typename std::enable_if<std::is_integral<T>::value, bool>::type
+    get_arr_n(const std::string & key, T & result, const bool required = true) {
+        const int kid = gguf_find_key(ctx_gguf, key.c_str());
 
-        std::string arch_name;
-        GGUF_GET_KEY(ctx_gguf, arch_name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_ARCHITECTURE));
+        if (kid < 0) {
+            if (required) {
+                throw std::runtime_error(format("key not found in model: %s", key.c_str()));
+            }
+            return false;
+        }
 
+        struct GGUFMeta::ArrayInfo arr_info =
+            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx_gguf, kid);
+
+
+        result = arr_info.length;
+        return true;
+    }
+
+    template<typename T>
+    typename std::enable_if<std::is_integral<T>::value, bool>::type
+    get_arr_n(const enum llm_kv kid, T & result, const bool required = true) {
+        return get_arr_n(llm_kv(kid), result, required);
+    }
+
+    template<typename T>
+    bool get_key(const std::string & key, T & result, const bool required = true) {
+        auto it = kv_overrides.find(key);
+
+        const struct llama_model_kv_override * override =
+            it != kv_overrides.end() ? &it->second : nullptr;
+
+        const bool found = GGUFMeta::GKV<T>::set(ctx_gguf, key, result, override);
+
+        if (required && !found) {
+            throw std::runtime_error(format("key not found in model: %s", key.c_str()));
+        }
+
+        return found;
+    }
+
+    template<typename T>
+    bool get_key(const enum llm_kv kid, T & result, const bool required = true) {
+        return get_key(llm_kv(kid), result, required);
+    }
+
+    std::string get_arch_name() const {
         return arch_name;
     }
 
     enum llm_arch get_arch() const {
-        const std::string arch_name = get_arch_name();
-
-        return llm_arch_from_string(arch_name);
+        return llm_kv.arch;
     }
 
     const char * get_tensor_name(int i) const {
@@ -2201,11 +2404,8 @@ static void llm_load_arch(llama_model_loader & ml, llama_model & model) {
 static void llm_load_hparams(
         llama_model_loader & ml,
         llama_model & model) {
-    struct gguf_context * ctx = ml.ctx_gguf;
-
-    const auto kv = LLM_KV(model.arch);
-
     auto & hparams = model.hparams;
+    const gguf_context * ctx = ml.ctx_gguf;
 
     // get metadata as string
     for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
@@ -2219,42 +2419,41 @@ static void llm_load_hparams(
     }
 
     // get general kv
-    GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME));
+    ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);
 
     // get hparams kv
-    GGUF_GET_KEY(ctx, hparams.n_vocab,        gguf_get_arr_n,   GGUF_TYPE_ARRAY,  true, kv(LLM_KV_TOKENIZER_LIST));
-    GGUF_GET_KEY(ctx, hparams.n_ctx_train,    gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH));
-    GGUF_GET_KEY(ctx, hparams.n_embd,         gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
-    GGUF_GET_KEY(ctx, hparams.n_ff,           gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
-    GGUF_GET_KEY(ctx, hparams.n_head,         gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
-    GGUF_GET_KEY(ctx, hparams.n_layer,        gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
+    ml.get_arr_n(LLM_KV_TOKENIZER_LIST,       hparams.n_vocab);
+    ml.get_key  (LLM_KV_CONTEXT_LENGTH,       hparams.n_ctx_train);
+    ml.get_key  (LLM_KV_EMBEDDING_LENGTH,     hparams.n_embd);
+    ml.get_key  (LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff);
+    ml.get_key  (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
+    ml.get_key  (LLM_KV_BLOCK_COUNT,          hparams.n_layer);
 
     // n_head_kv is optional, default to n_head
     hparams.n_head_kv = hparams.n_head;
-    GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
+    ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false);
 
-    hparams.rope_finetuned = false;
-    GGUF_GET_KEY(ctx, hparams.rope_finetuned, gguf_get_val_bool, GGUF_TYPE_BOOL, false,
-                 kv(LLM_KV_ROPE_SCALING_FINETUNED));
+    bool rope_finetuned = false;
+    ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
+    hparams.rope_finetuned = rope_finetuned;
 
     hparams.n_yarn_orig_ctx = hparams.n_ctx_train;
-    GGUF_GET_KEY(ctx, hparams.n_yarn_orig_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false,
-                 kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN));
+    ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_yarn_orig_ctx, false);
 
     // rope_freq_base (optional)
     hparams.rope_freq_base_train = 10000.0f;
-    GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
+    ml.get_key(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train, false);
 
     std::string rope_scaling("linear");
-    GGUF_GET_KEY(ctx, rope_scaling, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_ROPE_SCALING_TYPE));
+    ml.get_key(LLM_KV_ROPE_SCALING_TYPE, rope_scaling, false);
     hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
     GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED);
 
     // rope_freq_scale (inverse of the kv) is optional
     float ropescale = 0.0f;
-    GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALING_FACTOR));
-    if (ropescale == 0.0f) { // try the old key name
-        GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+    if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) {
+        // try the old key name
+        ml.get_key(LLM_KV_ROPE_SCALE_LINEAR, ropescale, false);
     }
     hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
 
@@ -2262,7 +2461,7 @@ static void llm_load_hparams(
     {
         hparams.n_rot = hparams.n_embd / hparams.n_head;
 
-        GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
+        ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
 
         if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
             if (hparams.n_rot != hparams.n_embd / hparams.n_head) {
@@ -2277,7 +2476,7 @@ static void llm_load_hparams(
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
             {
-                GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
                 switch (hparams.n_layer) {
                     case 26: model.type = e_model::MODEL_3B; break;
@@ -2291,7 +2490,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_FALCON:
             {
-                GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
 
                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_7B; break;
@@ -2301,7 +2500,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_BAICHUAN:
             {
-                GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_7B; break;
                     case 40: model.type = e_model::MODEL_13B; break;
@@ -2310,7 +2509,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_STARCODER:
             {
-                GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 switch (hparams.n_layer) {
                     case 24: model.type = e_model::MODEL_1B; break;
                     case 36: model.type = e_model::MODEL_3B; break;
@@ -2321,7 +2520,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_PERSIMMON:
             {
-                GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 switch (hparams.n_layer) {
                     case 36: model.type = e_model::MODEL_8B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
@@ -2329,7 +2528,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_REFACT:
             {
-                GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_1B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
@@ -2337,7 +2536,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_BLOOM:
             {
-                GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
 
                 switch (hparams.n_layer) {
                     case 24: model.type = e_model::MODEL_1B; break;
@@ -2352,9 +2551,9 @@ static void llm_load_hparams(
             {
                 hparams.f_clamp_kqv = 0.0f;
 
-                GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
-                GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_CLAMP_KQV));
-                GGUF_GET_KEY(ctx, hparams.f_max_alibi_bias, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS));
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,  hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV,      hparams.f_clamp_kqv, false);
+                ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
 
                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_7B; break;
@@ -2364,7 +2563,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_STABLELM:
             {
-                GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
 
                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_3B; break;
@@ -2373,7 +2572,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_QWEN:
             {
-                GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_7B; break;
                     case 40: model.type = e_model::MODEL_13B; break;
@@ -2421,7 +2621,7 @@ static void llm_load_vocab(
     {
         std::string tokenizer_name;
 
-        GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL));
+        ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name);
 
         if (tokenizer_name == "llama") {
             vocab.type = LLAMA_VOCAB_TYPE_SPM;
@@ -2511,34 +2711,31 @@ static void llm_load_vocab(
         };
         for (const auto & it : special_token_types) {
             const std::string & key = kv(std::get<0>(it));
-            int32_t & id = std::get<1>(it), old_id = id;
+            int32_t & id = std::get<1>(it);
 
-            GGUF_GET_KEY(ctx, id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, key);
-            // Must be >= -1 and < vocab size. Since the key is unsigned, -1
-            // can only come from the default value, so there's no point in
-            // validating that.
-            if (size_t(id + 1) > vocab.id_to_token.size()) {
-                LLAMA_LOG_WARN("%s: bad special token: '%s' = %d, using default id %d\n",
-                    __func__, key.c_str(), id, old_id);
-                id = old_id;
+            uint32_t new_id;
+            if (!ml.get_key(std::get<0>(it), new_id, false)) {
+                continue;
+            }
+            if (new_id >= vocab.id_to_token.size()) {
+                LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n",
+                    __func__, key.c_str(), new_id, id);
+            } else {
+                id = new_id;
             }
 
         }
 
         // Handle add_bos_token and add_eos_token
-        std::string key = kv(LLM_KV_TOKENIZER_ADD_BOS);
-        int kid = gguf_find_key(ctx, key.c_str());
-        enum gguf_type ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
-        vocab.special_add_bos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
-        if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
-            LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
-        }
-        key = kv(LLM_KV_TOKENIZER_ADD_EOS);
-        kid = gguf_find_key(ctx, key.c_str());
-        ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
-        vocab.special_add_eos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
-        if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
-            LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
+        {
+            bool temp = true;
+
+            if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
+                vocab.special_add_bos = int(temp);
+            }
+            if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
+                vocab.special_add_eos = int(temp);
+            }
         }
     }
 
@@ -3487,7 +3684,7 @@ static void llm_load_tensors(
 
 static bool llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
     try {
-        llama_model_loader ml(fname, params.use_mmap);
+        llama_model_loader ml(fname, params.use_mmap, params.kv_overrides);
 
         model.hparams.vocab_only = params.vocab_only;
 
@@ -8078,7 +8275,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     constexpr bool use_mmap = false;
 #endif
 
-    llama_model_loader ml(fname_inp, use_mmap);
+    llama_model_loader ml(fname_inp, use_mmap, NULL);
     if (ml.use_mmap) {
         ml.mapping.reset(new llama_mmap(&ml.file, /* prefetch */ 0, ggml_is_numa()));
     }
@@ -8374,7 +8571,7 @@ static int llama_apply_lora_from_file_internal(
     std::vector<uint8_t> base_buf;
     if (path_base_model) {
         LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
-        ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true));
+        ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/ NULL));
 
         size_t ctx_size;
         size_t mmapped_size;
@@ -8602,6 +8799,7 @@ struct llama_model_params llama_model_default_params() {
         /*.tensor_split                =*/ nullptr,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
+        /*.kv_overrides                =*/ nullptr,
         /*.vocab_only                  =*/ false,
         /*.use_mmap                    =*/ true,
         /*.use_mlock                   =*/ false,
diff --git a/llama.h b/llama.h
index 89cb6198e84b8c3e9c72abafdb00331460e8072f..517245a3543004969306984ebcd8f147d90373c8 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -158,6 +158,22 @@ extern "C" {
         llama_seq_id all_seq_id; // used if seq_id == NULL
     } llama_batch;
 
+    enum llama_model_kv_override_type {
+        LLAMA_KV_OVERRIDE_INT,
+        LLAMA_KV_OVERRIDE_FLOAT,
+        LLAMA_KV_OVERRIDE_BOOL,
+    };
+
+    struct llama_model_kv_override {
+        char key[128];
+        enum llama_model_kv_override_type tag;
+        union {
+            int64_t int_value;
+            double float_value;
+            bool bool_value;
+        };
+    };
+
     struct llama_model_params {
         int32_t n_gpu_layers; // number of layers to store in VRAM
         int32_t main_gpu;     // the GPU that is used for scratch and small tensors
@@ -165,9 +181,13 @@ extern "C" {
 
         // called with a progress value between 0 and 1, pass NULL to disable
         llama_progress_callback progress_callback;
+
         // context pointer passed to the progress callback
         void * progress_callback_user_data;
 
+        // override key-value pairs of the model meta data
+        const struct llama_model_kv_override * kv_overrides;
+
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool vocab_only; // only load the vocabulary, no weights
         bool use_mmap;   // use mmap if possible