]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
convert : handle max_target_positions (#2477)
authorCrispStrobe <redacted>
Mon, 14 Oct 2024 07:46:33 +0000 (09:46 +0200)
committerGitHub <redacted>
Mon, 14 Oct 2024 07:46:33 +0000 (10:46 +0300)
as needed eg for
https://huggingface.co/primeline/whisper-large-v3-turbo-german/blob/main/config.json

models/convert-h5-to-ggml.py

index 50836a216fb8d954f4afdacb062b2e049b621631..5474d58613ac343b40662a4e947f5560ea6cbf07 100644 (file)
@@ -82,7 +82,11 @@ dir_out     = Path(sys.argv[3])
 
 encoder = json.load((dir_model / "vocab.json").open("r", encoding="utf8"))
 encoder_added = json.load((dir_model / "added_tokens.json").open( "r", encoding="utf8"))
-hparams = json.load((dir_model / "config.json").open("r", encoding="utf8") )
+hparams = json.load((dir_model / "config.json").open("r", encoding="utf8"))
+
+# Add this block to handle missing 'max_length'
+if "max_length" not in hparams:
+    hparams["max_length"] = hparams.get("max_target_positions", 448)
 
 model = WhisperForConditionalGeneration.from_pretrained(dir_model)