convert.py : export rope freq_base when converting CodeLlama from an HF model (#2773)
author slaren <redacted>
Fri, 25 Aug 2023 12:08:53 +0000 (14:08 +0200)
committer GitHub <redacted>
Fri, 25 Aug 2023 12:08:53 +0000 (14:08 +0200)
convert.py

index 10276bf6300317652919b36bdcd670a86e1557bc..e58ea46e0e7d3db5b70db509fdc5a7e12eebac19 100755 (executable)
@@ -160,13 +160,14 @@ class Params:
     def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
         config = json.load(open(config_path))
 
-        n_vocab    = config["vocab_size"]
-        n_embd     = config["hidden_size"]
-        n_layer    = config["num_hidden_layers"]
-        n_ff       = config["intermediate_size"]
-        n_head     = config["num_attention_heads"]
-        n_head_kv  = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
-        f_norm_eps = config["rms_norm_eps"]
+        n_vocab          = config["vocab_size"]
+        n_embd           = config["hidden_size"]
+        n_layer          = config["num_hidden_layers"]
+        n_ff             = config["intermediate_size"]
+        n_head           = config["num_attention_heads"]
+        n_head_kv        = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
+        f_norm_eps       = config["rms_norm_eps"]
+        f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
 
         n_mult = Params.find_n_mult(n_ff, n_embd)
 
@@ -179,15 +180,16 @@ class Params:
                             "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
 
         return Params(
-            n_vocab    = n_vocab,
-            n_embd     = n_embd,
-            n_mult     = n_mult,
-            n_layer    = n_layer,
-            n_ctx      = n_ctx,
-            n_ff       = n_ff,
-            n_head     = n_head,
-            n_head_kv  = n_head_kv,
-            f_norm_eps = f_norm_eps,
+            n_vocab          = n_vocab,
+            n_embd           = n_embd,
+            n_mult           = n_mult,
+            n_layer          = n_layer,
+            n_ctx            = n_ctx,
+            n_ff             = n_ff,
+            n_head           = n_head,
+            n_head_kv        = n_head_kv,
+            f_norm_eps       = f_norm_eps,
+            f_rope_freq_base = f_rope_freq_base,
         )
 
     # LLaMA v2 70B params.json
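
For illustration, a minimal standalone sketch (not part of the commit) of how the converter now picks up the rope frequency base from an HF config.json; the helper name and example path below are hypothetical:

    import json
    from pathlib import Path

    def read_rope_freq_base(config_path: Path):
        # HF CodeLlama checkpoints carry "rope_theta" (1000000.0) in config.json;
        # older LLaMA configs omit the key, in which case None is returned and
        # the default rope frequency base is used downstream.
        config = json.load(open(config_path))
        return config["rope_theta"] if "rope_theta" in config else None

    # Example: read_rope_freq_base(Path("CodeLlama-7b-hf/config.json")) -> 1000000.0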