]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
models : minor changes to the HF convert script (#157)
authorGeorgi Gerganov <redacted>
Wed, 23 Nov 2022 20:07:20 +0000 (22:07 +0200)
committerGeorgi Gerganov <redacted>
Wed, 23 Nov 2022 20:07:20 +0000 (22:07 +0200)
models/convert-h5-to-ggml.py

index 61d29e81ecd168dff4d32754bde691453a5c6b02..7fef7ad2d5fc5cee5f4498761aa29268d79789c0 100644 (file)
@@ -28,6 +28,7 @@ conv_map = {'self_attn_layer_norm': 'attn_ln',
  'decoder.layer_norm.weight': 'decoder.ln.weight',
  'decoder.embed_positions.weight': 'decoder.positional_embedding',
  'decoder.embed_tokens.weight': 'decoder.token_embedding.weight',
+ 'proj_out.weight': 'decoder.proj.weight',
 }
 
 # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
@@ -82,8 +83,11 @@ fname_out = dir_out + "/ggml-model.bin"
 with open(dir_tokenizer + "/vocab.json", "r", encoding="utf8") as f:
     tokens = json.load(f)
 
-
+# use 16-bit or 32-bit floats
 use_f16 = True
+if len(sys.argv) > 4:
+    use_f16 = False
+    fname_out = dir_out + "/ggml-model-f32.bin"
 
 fout = open(fname_out, "wb")
 
@@ -119,6 +123,8 @@ for key in tokens:
 
 list_vars = model.state_dict()
 for name in list_vars.keys():
+    # this seems to not be used
+    # ref: https://github.com/huggingface/transformers/blob/9a5b84a0076a04fe9596da72e8668069d4f09ea0/src/transformers/models/whisper/modeling_whisper.py#L1099-L1106
     if name == "proj_out.weight":
         print('Skipping', name)
         continue
@@ -126,7 +132,11 @@ for name in list_vars.keys():
     src = name
 
     nn = name
-    nn = nn.split(".")[1:]
+    if name != "proj_out.weight":
+        nn = nn.split(".")[1:]
+    else:
+        nn = nn.split(".")
+
     if nn[1] == "layers":
         nn[1] = "blocks"
         if ".".join(nn[3:-1]) == "self_attn.k_proj":