convert : handle tokenizer merges format from transformers 4.45 (#9696)
author     compilade <redacted>
           Thu, 3 Oct 2024 14:22:15 +0000 (10:22 -0400)
committer  GitHub <redacted>
           Thu, 3 Oct 2024 14:22:15 +0000 (17:22 +0300)
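
Before transformers 4.45, tokenizer.json stored each BPE merge as a single
space-joined string; 4.45 switched to a pair-of-strings format so that merge
parts may themselves contain literal spaces. A minimal sketch of the two
shapes as parsed from JSON (values are illustrative, not taken from any real
tokenizer):

    # old format (transformers < 4.45): one space-joined string per merge
    merges_old = ["Ġ t", "h e", "i n"]

    # new format (transformers >= 4.45): a two-element list per merge,
    # which can represent parts that contain literal spaces
    merges_new = [["Ġ", "t"], ["h e", "llo"]]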
gguf-py/gguf/vocab.py

index dc574991381a8a9558611cb729c4fb763dd4c444..f2645f92101dbd118172ac30a749add9ae4a394a 100644
@@ -122,8 +122,30 @@ class SpecialVocab:
                 tokenizer = json.load(f)
             if self.load_merges:
                 merges = tokenizer.get('model', {}).get('merges')
-                if isinstance(merges, list) and merges and isinstance(merges[0], str):
-                    self.merges = merges
+                if isinstance(merges, list) and merges:
+                    if isinstance(merges[0], str):
+                        self.merges = merges
+                    elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str):
+                        # New format since transformers 4.45 to support spaces in merges
+                        # ref: https://github.com/ggerganov/llama.cpp/issues/9692
+                        # TODO: internally store as the new format instead of converting to old
+                        if any(' ' in s for pair in merges for s in pair):
+                            logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}')
+                        self.merges = [
+                            ' '.join(
+                                [
+                                    # ensure the spaces are properly encoded
+                                    ''.join(
+                                        chr(ord(c) + 256) if c == ' ' else c
+                                        for c in part
+                                    )
+                                    for part in pair
+                                ]
+                            )
+                            for pair in merges
+                        ]
+                    else:
+                        raise ValueError("Unknown tokenizer merges format")
             added_tokens = tokenizer.get('added_tokens', {})
         else:
             added_tokens = {}
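
For reference, a standalone sketch of the same normalization applied above
(not part of the commit; the function name is hypothetical), runnable on a
parsed 'model.merges' value:

    # Sketch assuming `merges` is tokenizer['model']['merges'] from tokenizer.json.
    def normalize_merges(merges: list) -> list[str]:
        if merges and isinstance(merges[0], str):
            return merges  # old format: already space-joined strings
        if merges and isinstance(merges[0], list) and len(merges[0]) == 2:
            # new format: encode literal spaces as chr(ord(' ') + 256) == '\u0120'
            # ('Ġ'), then join each pair with a single space, recovering the
            # old space-joined representation
            return [
                ' '.join(
                    ''.join(chr(ord(c) + 256) if c == ' ' else c for c in part)
                    for part in pair
                )
                for pair in merges
            ]
        raise ValueError("Unknown tokenizer merges format")

    # a pair whose first part contains a literal space
    print(normalize_merges([["h e", "llo"]]))  # ['hĠe llo']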