if chkhsh == "8aeee3860c56296a157a1fe2fad249ec40aa59b1bb5709f4ade11c4e6fe652ed":
# ref: https://huggingface.co/tiiuae/falcon-7b
res = "falcon"
- if chkhsh == "9d032fcbd5501f4a38150912590928bfb36091efb5df11b8e2124b0390e3fb1e":
- # ref: https://huggingface.co/tiiuae/Falcon3-7B-Base
- res = "falcon3"
if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
# ref: https://huggingface.co/BAAI/bge-small-en-v1.5
res = "bert-bge"
+ if chkhsh == "9d032fcbd5501f4a38150912590928bfb36091efb5df11b8e2124b0390e3fb1e":
+ # ref: https://huggingface.co/tiiuae/Falcon3-7B-Base
+ res = "falcon3"
if chkhsh == "8e62295832751ca1e8f92f2226f403dea30dc5165e448b5bfa05af5340c64ec7":
# ref: https://huggingface.co/BAAI/bge-large-zh-v1.5
res = "bert-bge-large"
if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
res = "jina-v2-code"
- if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b" or chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516":
- # ref: https://huggingface.co/THUDM/glm-4-9b-chat
- res = "chatglm-bpe"
if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
# ref: https://huggingface.co/LumiOpen/Viking-7B
res = "viking"
if chkhsh == "60824e3c0d9401f89943cbb2fff727f0e2d4c545ba4df2d6e4f09a6db0f5b450":
# ref: https://huggingface.co/facebook/chameleon-7b
res = "chameleon"
- if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
- # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
- res = "minerva-7b"
if chkhsh == "8b5a93ed704057481f240da0be7e7dca721d7f8f4755263b6807227a2cbeae65":
# ref: https://huggingface.co/sentence-transformers/stsb-roberta-base
res = "roberta-bpe"
if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406":
# ref: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct
res = "llama4"
- if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
- # ref: https://huggingface.co/THUDM/glm-4-9b-hf
- res = "glm4"
if chkhsh == "0e9433cbbb161f89e264eb32e8e64bfe69e834973ffca5d41d3948a604a3e2a3":
# ref: https://huggingface.co/mistral-community/pixtral-12b
res = "pixtral"
if chkhsh == "d5f1dd6f980fec569fb218a81a7658ac45fc56b38c5a0adeb1c232fbe04ef5ec":
# ref: https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base
res = "seed-coder"
+ if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
+ # ref: https://huggingface.co/THUDM/glm-4-9b-chat
+ res = "chatglm-bpe"
+ if chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516":
+ # ref: https://huggingface.co/THUDM/glm-4-9b-chat
+ res = "chatglm-bpe"
+ if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
+ # ref: https://huggingface.co/THUDM/glm-4-9b-hf
+ res = "glm4"
+ if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
+ # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
+ res = "minerva-7b"
if res is None:
logger.warning("\n")
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
-# This script downloads the tokenizer models of the specified models from Huggingface and
-# generates the get_vocab_base_pre() function for convert_hf_to_gguf.py
-#
-# This is necessary in order to analyze the type of pre-tokenizer used by the model and
-# provide the necessary information to llama.cpp via the GGUF header in order to implement
-# the same pre-tokenizer.
-#
-# ref: https://github.com/ggml-org/llama.cpp/pull/6920
-#
-# Instructions:
-#
-# - Add a new model to the "models" list
-# - Run the script with your huggingface token:
-#
-# python3 convert_hf_to_gguf_update.py <huggingface_token>
-#
-# - The convert_hf_to_gguf.py script will have had its get_vocab_base_pre() function updated
-# - Update llama.cpp with the new pre-tokenizer if necessary
-#
-# TODO: generate tokenizer tests for llama.cpp
-#
-
import logging
import os
import pathlib
import re
import requests
import sys
import json
import shutil
+import argparse
from hashlib import sha256
from enum import IntEnum, auto
from transformers import AutoTokenizer
logger = logging.getLogger("convert_hf_to_gguf_update")
sess = requests.Session()
+convert_py_pth = pathlib.Path("convert_hf_to_gguf.py")
+convert_py = convert_py_pth.read_text(encoding="utf-8")
+hf_token_pth = pathlib.Path.home() / ".cache" / "huggingface" / "token"
+hf_token = hf_token_pth.read_text(encoding="utf-8").strip() if hf_token_pth.exists() else None
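+# note: a token passed on the command line (parsed below) takes precedence over this cached one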
+
class TOKENIZER_TYPE(IntEnum):
    SPM = auto()
    BPE = auto()
    WPM = auto()
    UGM = auto()
+DOC_STRING = """
+This script downloads the tokenizer models of the specified models from Huggingface and
+generates the get_vocab_base_pre() function for convert_hf_to_gguf.py
+
+/!\\ It is intended to be used by contributors and is not meant to be run by end users
+
+This is necessary in order to analyze the type of pre-tokenizer used by the model and
+provide the necessary information to llama.cpp via the GGUF header in order to implement
+the same pre-tokenizer.
+
+ref: https://github.com/ggml-org/llama.cpp/pull/6920
+
+Instructions:
+
+- Add a new model to the "models" list
+- Run the script with your huggingface token
+  (by default, the token is read from ~/.cache/huggingface/token)
+- The convert_hf_to_gguf.py script will have its get_vocab_base_pre() function updated
+- Update llama.cpp with the new pre-tokenizer if necessary
+"""
+# TODO: generate tokenizer tests for llama.cpp
+
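+# Example invocations (a sketch; "hf_xxx..." is a placeholder token, not a real one):
+#
+#   python3 convert_hf_to_gguf_update.py hf_xxxxxxxxxxxxxxxxxxxxxx
+#   python3 convert_hf_to_gguf_update.py --full   # token read from ~/.cache/huggingface/token
+#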
+parser = argparse.ArgumentParser(description=DOC_STRING, formatter_class=argparse.RawTextHelpFormatter)
+parser.add_argument(
+ "--full", action="store_true",
+ help="download full list of models - make sure you have access to all of them",
+)
+parser.add_argument(
+ "hf_token",
+ help="optional HF token",
+ nargs="?",
+)
+args = parser.parse_args()
+hf_token = args.hf_token if args.hf_token is not None else hf_token
+
+if hf_token is None:
+ logger.error("HF token is required. Please provide it as an argument or set it in ~/.cache/huggingface/token")
+ sys.exit(1)
+
# TODO: this string has to exercise as much pre-tokenizer functionality as possible
# will be updated with time - contributions welcome
CHK_TXT = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````\"\"\"\"......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL'
-if len(sys.argv) == 2:
- token = sys.argv[1]
- if not token.startswith("hf_"):
- logger.info("Huggingface token seems invalid")
- logger.info("Usage: python convert_hf_to_gguf_update.py <huggingface_token>")
- sys.exit(1)
-else:
- logger.info("Usage: python convert_hf_to_gguf_update.py <huggingface_token>")
- sys.exit(1)
-
# TODO: add models here, base models preferred
models = [
{"name": "llama-spm", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/meta-llama/Llama-2-7b-hf", },
{"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", },
{"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", },
{"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", },
- {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", },
{"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"},
{"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"},
{"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"},
{"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", },
{"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", },
{"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
- {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", },
{"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
{"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
]
+# some models are known to be broken upstream, so their hashes are pre-computed here instead of being downloaded and re-hashed
+pre_computed_hashes = [
+    # chatglm-bpe has produced two different hashes over time (likely an upstream tokenizer update), so both are kept
+ {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b"},
+ {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
+ {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
+ {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
+]
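+# these entries are neither downloaded nor re-hashed; their "chkhsh" values are emitted
+# verbatim into get_vocab_base_pre() by the generation loop below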
+
def download_file_with_auth(url, token, save_path):
headers = {"Authorization": f"Bearer {token}"}
if os.path.isfile(save_path):
logger.info(f"{name}: File {save_path} already exists - skipping")
continue
- download_file_with_auth(f"{repo}/resolve/main/{file}", token, save_path)
+ download_file_with_auth(f"{repo}/resolve/main/{file}", hf_token, save_path)
+
+
+# get the list of existing models and their chkhsh values from the convert_hf_to_gguf.py file
+# returns a mapping res --> chkhsh
+def get_existing_models(convert_py):
+ pattern = r'if chkhsh == "([a-f0-9]{64})":\s*\n\s*.*\s*res = "([^"]+)"'
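+    # for example, the pattern matches generated blocks of the form:
+    #     if chkhsh == "8aeee3860c56296a157a1fe2fad249ec40aa59b1bb5709f4ade11c4e6fe652ed":
+    #         # ref: https://huggingface.co/tiiuae/falcon-7b
+    #         res = "falcon"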
+ matches = re.findall(pattern, convert_py)
+ output = {}
+ for chkhsh, res in matches:
+ output[res] = chkhsh
+ return output
+
+existing_models = {}
+all_models = models.copy()
+if not args.full:
+ # Filter out models that already exist in convert_hf_to_gguf.py
+ existing_models = get_existing_models(convert_py)
+ models = [model for model in all_models if model["name"] not in existing_models]
+logger.info(f"Downloading {len(models)} models...")
for model in models:
try:
download_model(model)
# generate the source code for the convert_hf_to_gguf.py:get_vocab_base_pre() function:
src_ifs = ""
-for model in models:
+for model in [*all_models, *pre_computed_hashes]:
name = model["name"]
tokt = model["tokt"]
+ chkhsh = model.get("chkhsh")
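+    # "chkhsh" is only set for entries coming from pre_computed_hashes;
+    # it is None for models that still need to be downloaded and hashed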
if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
continue
continue
# create the tokenizer
- try:
- if name == "t5":
- tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
- else:
- tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
- except OSError as e:
- logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
- continue # Skip to the next model if the tokenizer can't be loaded
-
- chktok = tokenizer.encode(CHK_TXT)
- chkhsh = sha256(str(chktok).encode()).hexdigest()
-
- logger.info(f"model: {name}")
- logger.info(f"tokt: {tokt}")
- logger.info(f"repo: {model['repo']}")
- logger.info(f"chktok: {chktok}")
- logger.info(f"chkhsh: {chkhsh}")
-
- # print the "pre_tokenizer" content from the tokenizer.json
- with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f:
- cfg = json.load(f)
- normalizer = cfg["normalizer"]
- logger.info("normalizer: " + json.dumps(normalizer, indent=4))
- pre_tokenizer = cfg["pre_tokenizer"]
- logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4))
- if "ignore_merges" in cfg["model"]:
- logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4))
-
- logger.info("")
+ if chkhsh is not None:
+ # if the model has a pre-computed hash, use it
+ logger.info(f"Using pre-computed hash for model {name}: {chkhsh}")
+ elif name in existing_models:
+        # if the model already exists in convert_hf_to_gguf.py, skip computing the hash
+ chkhsh = existing_models[name]
+ else:
+ # otherwise, compute the hash of the tokenizer
+ try:
+ logger.info(f"Loading tokenizer from {f'models/tokenizers/{name}'}...")
+ if name == "t5":
+ tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
+ except OSError as e:
+ logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
+ continue # Skip to the next model if the tokenizer can't be loaded
+
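+        # the fingerprint is the sha256 of the string form of the token ids produced for CHK_TXT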
+ chktok = tokenizer.encode(CHK_TXT)
+ chkhsh = sha256(str(chktok).encode()).hexdigest()
+
+ logger.info(f"model: {name}")
+ logger.info(f"tokt: {tokt}")
+ logger.info(f"repo: {model['repo']}")
+ logger.info(f"chktok: {chktok}")
+ logger.info(f"chkhsh: {chkhsh}")
+
+ # print the "pre_tokenizer" content from the tokenizer.json
+ with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f:
+ cfg = json.load(f)
+ normalizer = cfg["normalizer"]
+ logger.info("normalizer: " + json.dumps(normalizer, indent=4))
+ pre_tokenizer = cfg["pre_tokenizer"]
+ logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4))
+ if "ignore_merges" in cfg["model"]:
+ logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4))
+
+ logger.info("")
src_ifs += f" if chkhsh == \"{chkhsh}\":\n"
src_ifs += f" # ref: {model['repo']}\n"
return res
"""
-convert_py_pth = pathlib.Path("convert_hf_to_gguf.py")
-convert_py = convert_py_pth.read_text(encoding="utf-8")
convert_py = re.sub(
r"(# Marker: Start get_vocab_base_pre)(.+?)( +# Marker: End get_vocab_base_pre)",
lambda m: m.group(1) + src_func + m.group(3),
logger.error(f"Failed to load tokenizer for model {name}. Error: {e}")
continue # Skip this model and continue with the next one in the loop
+ if not os.path.exists(f"models/ggml-vocab-{name}.gguf"):
+ logger.info(f"Skip vocab files for model {name}, no GGUF file found")
+ continue
+
with open(f"models/ggml-vocab-{name}.gguf.inp", "w", encoding="utf-8") as f:
for text in tests:
f.write(f"{text}")