]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
gguf-py : add special token modification capability (#7166)
authorSigbjørn Skjæret <redacted>
Thu, 9 May 2024 10:56:00 +0000 (12:56 +0200)
committerGitHub <redacted>
Thu, 9 May 2024 10:56:00 +0000 (13:56 +0300)
* Add special token modification capability

To be able to fix/amend special tokens in a GGUF let's add two new arguments:
* `--special-token <name> <value>` where `<name>` can be bos, eos, prefix, middle, etc. while `<value>` is the token value, f.ex. `"<|fim▁begin|>"`
* `--special-token-by-id <name> <id>` where `<id>` is the ID of the token, f.ex. 32006

So, in order to f.ex. add fill-in-middle tokens to a GGUF you would do the following:
```bash
python3 gguf-new-metadata.py input.gguf output.gguf --special-token prefix "<|fim▁begin|>" --special-token middle "<|fim▁hole|>" --special-token suffix "<|fim▁end|>"
```

* improve help text

* flake--

* fix multiple tokens warning

* make script executable

* switch to namedtuple, no need to dataclass

* typing++

* add progress bar

* Add special token modification capability

To be able to fix/amend special tokens in a GGUF let's add two new arguments:
* `--special-token <name> <value>` where `<name>` can be bos, eos, prefix, middle, etc. while `<value>` is the token value, f.ex. `"<|fim▁begin|>"`
* `--special-token-by-id <name> <id>` where `<id>` is the ID of the token, f.ex. 32006

So, in order to f.ex. add fill-in-middle tokens to a GGUF you would do the following:
```bash
gguf-new-metadata.py input.gguf output.gguf --special-token prefix "<|fim▁begin|>" --special-token middle "<|fim▁end|>" --special-token suffix "<|fim▁hole|>"
```
(yes, fim_end is the `middle` token, because completion is a `prefix`/`suffix`/`middle` sequence (where `middle` is unfilled))
or
```bash
gguf-new-metadata.py input.gguf output.gguf --special-token prefix "<fim_prefix>" --special-token middle "<fim_middle>" --special-token suffix "<fim_suffix>"
```
etc...

NB: The tokens have to exist already, trying to add non-existent token name/IDs will be ignored (with a warning), while non-existent values will fail (with an error).

* improve help text

* flake--

* fix multiple tokens warning

* make script executable

* switch to namedtuple, no need to dataclass

* typing++

* add progress bar

* fail on invalid token id

gguf-py/scripts/gguf-new-metadata.py [changed mode: 0644->0755]

old mode 100644 (file)
new mode 100755 (executable)
index c8e3a83..63d3c5d
@@ -7,7 +7,8 @@ import json
 from pathlib import Path
 
 import numpy as np
-from typing import Any, Sequence
+from tqdm import tqdm
+from typing import Any, Sequence, NamedTuple
 
 # Necessary to load the local gguf package
 if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
@@ -18,6 +19,12 @@ import gguf
 logger = logging.getLogger("gguf-new-metadata")
 
 
+class MetadataDetails(NamedTuple):
+    type: gguf.GGUFValueType
+    value: Any
+    description: str = ''
+
+
 def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian:
     if np.uint32(1) == np.uint32(1).newbyteorder("<"):
         # Host is little endian
@@ -59,7 +66,16 @@ def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
     return decode_field(field)
 
 
-def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, str], remove_metadata: Sequence[str]) -> None:
+def find_token(token_list: Sequence[int], token: str) -> Sequence[int]:
+    token_ids = [index for index, value in enumerate(token_list) if value == token]
+
+    if len(token_ids) == 0:
+        raise LookupError(f'Unable to find "{token}" in token list!')
+
+    return token_ids
+
+
+def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, MetadataDetails], remove_metadata: Sequence[str]) -> None:
     for field in reader.fields.values():
         # Suppress virtual fields and fields written by GGUFWriter
         if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
@@ -75,54 +91,64 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
             logger.debug(f'Removing {field.name}')
             continue
 
-        old_val = decode_field(field)
+        old_val = MetadataDetails(field.types[0], decode_field(field))
         val = new_metadata.get(field.name, old_val)
 
         if field.name in new_metadata:
-            logger.debug(f'Modifying {field.name}: "{old_val}" -> "{val}"')
+            logger.debug(f'Modifying {field.name}: "{old_val.value}" -> "{val.value}" {val.description}')
             del new_metadata[field.name]
-        elif val is not None:
+        elif val.value is not None:
             logger.debug(f'Copying {field.name}')
 
-        if val is not None:
+        if val.value is not None:
             writer.add_key(field.name)
-            writer.add_val(val, field.types[0])
+            writer.add_val(val.value, val.type)
 
     if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
         logger.debug('Adding chat template(s)')
-        writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE])
+        writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE].value)
         del new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE]
 
-    # TODO: Support other types than string?
     for key, val in new_metadata.items():
-        logger.debug(f'Adding {key}: {val}')
+        logger.debug(f'Adding {key}: "{val.value}" {val.description}')
         writer.add_key(key)
-        writer.add_val(val, gguf.GGUFValueType.STRING)
+        writer.add_val(val.value, val.type)
+
+    total_bytes = 0
 
     for tensor in reader.tensors:
+        total_bytes += tensor.n_bytes
         # Dimensions are written in reverse order, so flip them first
         shape = np.flipud(tensor.shape).tolist()
         writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
 
+    bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+
     writer.write_header_to_file()
     writer.write_kv_data_to_file()
     writer.write_ti_data_to_file()
 
     for tensor in reader.tensors:
         writer.write_tensor_data(tensor.data)
+        bar.update(tensor.n_bytes)
 
     writer.close()
 
 
 def main() -> None:
+    tokenizer_metadata = (getattr(gguf.Keys.Tokenizer, n) for n in gguf.Keys.Tokenizer.__dict__.keys() if not n.startswith('_'))
+    token_names = dict((n.split('.')[-1][:-len('_token_id')], n) for n in tokenizer_metadata if n.endswith('_token_id'))
+
     parser = argparse.ArgumentParser(description="Make a copy of a GGUF file with new metadata")
     parser.add_argument("input",                                       type=Path, help="GGUF format model input filename")
     parser.add_argument("output",                                      type=Path, help="GGUF format model output filename")
-    parser.add_argument("--general-name",                              type=str,  help="The models general.name")
-    parser.add_argument("--general-description",                       type=str,  help="The models general.description")
-    parser.add_argument("--chat-template",                             type=str,  help="Chat template string (or JSON string containing templates)")
-    parser.add_argument("--chat-template-config",                      type=Path, help="Config file (tokenizer_config.json) containing chat template(s)")
-    parser.add_argument("--remove-metadata",      action="append",     type=str,  help="Remove metadata (by key name) from output model")
+    parser.add_argument("--general-name",                              type=str,  help="The models general.name", metavar='"name"')
+    parser.add_argument("--general-description",                       type=str,  help="The models general.description", metavar='"Description ..."')
+    parser.add_argument("--chat-template",                             type=str,  help="Chat template string (or JSON string containing templates)", metavar='"{% ... %} ..."')
+    parser.add_argument("--chat-template-config",                      type=Path, help="Config file containing chat template(s)", metavar='tokenizer_config.json')
+    parser.add_argument("--remove-metadata",      action="append",     type=str,  help="Remove metadata (by key name) from output model", metavar='general.url')
+    parser.add_argument("--special-token",        action="append",     type=str,  help="Special token by value", nargs=2, metavar=(' | '.join(token_names.keys()), '"<token>"'))
+    parser.add_argument("--special-token-by-id",  action="append",     type=str,  help="Special token by id", nargs=2, metavar=(' | '.join(token_names.keys()), '0'))
     parser.add_argument("--force",                action="store_true",            help="Bypass warnings without confirmation")
     parser.add_argument("--verbose",              action="store_true",            help="Increase output verbosity")
     args = parser.parse_args(None if len(sys.argv) > 2 else ["--help"])
@@ -133,20 +159,20 @@ def main() -> None:
     remove_metadata = args.remove_metadata or []
 
     if args.general_name:
-        new_metadata[gguf.Keys.General.NAME] = args.general_name
+        new_metadata[gguf.Keys.General.NAME] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_name)
 
     if args.general_description:
-        new_metadata[gguf.Keys.General.DESCRIPTION] = args.general_description
+        new_metadata[gguf.Keys.General.DESCRIPTION] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_description)
 
     if args.chat_template:
-        new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template
+        new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template)
 
     if args.chat_template_config:
         with open(args.chat_template_config, 'r') as fp:
             config = json.load(fp)
             template = config.get('chat_template')
             if template:
-                new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = template
+                new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template)
 
     if remove_metadata:
         logger.warning('*** Warning *** Warning *** Warning **')
@@ -166,6 +192,32 @@ def main() -> None:
     arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE)
     endianess = get_byteorder(reader)
 
+    token_list = get_field_data(reader, gguf.Keys.Tokenizer.LIST) or []
+
+    for name, token in args.special_token or []:
+        if name not in token_names:
+            logger.warning(f'Unknown special token "{name}", ignoring...')
+        else:
+            ids = find_token(token_list, token)
+            new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, ids[0], f'= {token}')
+
+            if len(ids) > 1:
+                logger.warning(f'Multiple "{token}" tokens found, choosing ID {ids[0]}, use --special-token-by-id if you want another:')
+                logger.warning(', '.join(str(i) for i in ids))
+
+    for name, id_string in args.special_token_by_id or []:
+        if name not in token_names:
+            logger.warning(f'Unknown special token "{name}", ignoring...')
+        elif not id_string.isdecimal():
+            raise LookupError(f'Token ID "{id_string}" is not a valid ID!')
+        else:
+            id_int = int(id_string)
+
+            if id_int >= 0 and id_int < len(token_list):
+                new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, id_int, f'= {token_list[id_int]}')
+            else:
+                raise LookupError(f'Token ID {id_int} is not within token list!')
+
     if os.path.isfile(args.output) and not args.force:
         logger.warning('*** Warning *** Warning *** Warning **')
         logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!')