]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert : fix vocab padding code for bert models (#13954)
authorSigbjørn Skjæret <redacted>
Sun, 1 Jun 2025 15:23:11 +0000 (17:23 +0200)
committerGitHub <redacted>
Sun, 1 Jun 2025 15:23:11 +0000 (17:23 +0200)
convert_hf_to_gguf.py

index ab0f0e0ea087e8776750ea1e6eeb282c6340a691..42e8f9cc06e293a52cda118e720fff80475d69a9 100755 (executable)
@@ -3814,7 +3814,7 @@ class BertModel(TextModel):
             remove_whitespaces = tokenizer.clean_up_tokenization_spaces
             precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])
 
-            vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size)
+            vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size)
         else:
             sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
             sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
@@ -3827,7 +3827,7 @@ class BertModel(TextModel):
             tokenizer = SentencePieceProcessor()
             tokenizer.LoadFromFile(str(tokenizer_path))
 
-            vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+            vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size())
 
         tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
         scores: list[float] = [-10000.0] * vocab_size
@@ -3857,33 +3857,26 @@ class BertModel(TextModel):
             unk_token = tokenizer_config_json.get("unk_token")
             unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3))
 
-            for token_id in range(vocab_size):
+            for token_id in range(tokenizer.vocab_size):
                 piece = tokenizer._convert_id_to_token(token_id)
-                text = piece.encode("utf-8")
-                score = tokenizer_json["model"]["vocab"][token_id][1]
-
-                toktype = SentencePieceTokenTypes.NORMAL
-                if token_id == unk_token_id:
-                    toktype = SentencePieceTokenTypes.UNKNOWN
-                elif token_id in tokenizer.all_special_ids:
-                    toktype = SentencePieceTokenTypes.CONTROL
-                elif token_id in added_vocab.values():
-                    toktype = SentencePieceTokenTypes.USER_DEFINED
-                # No reliable way to detect this, but jina doesn't have any
-                # elif tokenizer.IsByte(token_id):
-                #     toktype = SentencePieceTokenTypes.BYTE
-
-                tokens[token_id] = text
-                scores[token_id] = score
-                toktypes[token_id] = toktype
-
-        if vocab_size > len(tokens):
-            pad_count = vocab_size - len(tokens)
-            logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
-            for i in range(1, pad_count + 1):
-                tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
-                scores.append(-1000.0)
-                toktypes.append(SentencePieceTokenTypes.UNUSED)
+                if (piece := tokenizer._convert_id_to_token(token_id)) is not None:
+                    text = piece.encode("utf-8")
+                    score = tokenizer_json["model"]["vocab"][token_id][1]
+
+                    toktype = SentencePieceTokenTypes.NORMAL
+                    if token_id == unk_token_id:
+                        toktype = SentencePieceTokenTypes.UNKNOWN
+                    elif token_id in tokenizer.all_special_ids:
+                        toktype = SentencePieceTokenTypes.CONTROL
+                    elif token_id in added_vocab.values():
+                        toktype = SentencePieceTokenTypes.USER_DEFINED
+                    # No reliable way to detect this, but jina doesn't have any
+                    # elif tokenizer.IsByte(token_id):
+                    #     toktype = SentencePieceTokenTypes.BYTE
+
+                    tokens[token_id] = text
+                    scores[token_id] = score
+                    toktypes[token_id] = toktype
 
         if isinstance(tokenizer, SentencePieceProcessor):
             # realign tokens (see HF tokenizer code)