]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : support models without vocabulary (#5798)
authorMichael Podvitskiy <redacted>
Thu, 14 Mar 2024 16:21:56 +0000 (17:21 +0100)
committerGitHub <redacted>
Thu, 14 Mar 2024 16:21:56 +0000 (18:21 +0200)
* additional methods to read model and ctx parameters

* vocab size as a part of a model metadata

* models without vocabulary, convert.py part

* models without vocabulary, llama.cpp part

* PR clean up

* converter scrypt fixes

* llama_vocab_type update (renamed the new key)

* pr review fixes

* revert function renaming

* one more NoVocab assert

convert.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
llama.cpp
llama.h

index c15f8c47ea4f7fd658d429002628b5b667cd1691..161430f3e717e144fac734fcf9549e5dd018c9ac 100755 (executable)
@@ -332,6 +332,9 @@ class Params:
 #
 
 class BpeVocab:
+    tokenizer_model = "gpt2"
+    name = "bpe"
+
     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
         self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
         if isinstance(self.bpe_tokenizer.get('model'), dict):
@@ -390,6 +393,9 @@ class BpeVocab:
 
 
 class SentencePieceVocab:
+    tokenizer_model = "llama"
+    name = "spm"
+
     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
         self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
         added_tokens: dict[str, int]
@@ -453,6 +459,9 @@ class SentencePieceVocab:
 
 
 class HfVocab:
+    tokenizer_model = "llama"
+    name = "hfft"
+
     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None = None) -> None:
         try:
             from transformers import AutoTokenizer
@@ -553,7 +562,15 @@ class HfVocab:
         return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
 
-Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab"
+class NoVocab:
+    tokenizer_model = "no_vocab"
+    name = "no_vocab"
+
+    def __repr__(self) -> str:
+        return "<NoVocab for a model without integrated vocabulary>"
+
+
+Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab | NoVocab"
 
 
 #
@@ -935,8 +952,10 @@ def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> N
     # Handle special case where the model's vocab size is not set
     if params.n_vocab == -1:
         raise ValueError(
-            f"The model's vocab size is set to -1 in params.json. Please update it manually. Maybe {vocab.vocab_size}?"
+            f"The model's vocab size is set to -1 in params.json. Please update it manually.{f' Maybe {vocab.vocab_size}?' if hasattr(vocab, 'vocab_size') else ''}"
         )
+    if isinstance(vocab, NoVocab):
+        return  # model has no vocab
 
     # Check for a vocab size mismatch
     if params.n_vocab == vocab.vocab_size:
@@ -977,6 +996,7 @@ class OutputFile:
             name = str(params.path_model.parent).split('/')[-1]
 
         self.gguf.add_name                (name)
+        self.gguf.add_vocab_size          (params.n_vocab)
         self.gguf.add_context_length      (params.n_ctx)
         self.gguf.add_embedding_length    (params.n_embd)
         self.gguf.add_block_count         (params.n_layer)
@@ -1013,21 +1033,9 @@ class OutputFile:
         if params.ftype is not None:
             self.gguf.add_file_type(params.ftype)
 
-    def handle_tokenizer_model(self, vocab: Vocab) -> str:
-        # Map the vocab types to the supported tokenizer models
-        tokenizer_model = {
-            SentencePieceVocab: "llama",
-            HfVocab: "llama",
-            BpeVocab: "gpt2",
-        }.get(type(vocab))
-
-        # Block if vocab type is not predefined
-        if tokenizer_model is None:
-            raise ValueError("Unknown vocab type: Not supported")
-
-        return tokenizer_model
-
     def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]:
+        assert not isinstance(vocab, NoVocab)
+
         tokens = []
         scores = []
         toktypes = []
@@ -1043,11 +1051,8 @@ class OutputFile:
         return tokens, scores, toktypes
 
     def add_meta_vocab(self, vocab: Vocab) -> None:
-        # Handle the tokenizer model
-        tokenizer_model = self.handle_tokenizer_model(vocab)
-
         # Ensure that tokenizer_model is added to the GGUF model
-        self.gguf.add_tokenizer_model(tokenizer_model)
+        self.gguf.add_tokenizer_model(vocab.tokenizer_model)
 
         # Extract model vocabulary for model conversion
         tokens, scores, toktypes = self.extract_vocabulary_from_model(vocab)
@@ -1074,6 +1079,26 @@ class OutputFile:
     def write_tensor_info(self) -> None:
         self.gguf.write_ti_data_to_file()
 
+    def write_tensor_data(self, ftype: GGMLFileType, model: LazyModel, concurrency: int) -> None:
+        ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency)
+        if ftype == GGMLFileType.MostlyQ8_0:
+            ndarrays = bounded_parallel_map(
+                OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency,
+                use_processpool_executor=True,
+            )
+        else:
+            ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
+
+        start = time.time()
+        for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
+            elapsed = time.time() - start
+            size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
+            padi = len(str(len(model)))
+            print(
+                f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}"
+            )
+            self.gguf.write_tensor_data(ndarray)
+
     def close(self) -> None:
         self.gguf.close()
 
@@ -1082,7 +1107,7 @@ class OutputFile:
         fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
         endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False,
     ) -> None:
-        check_vocab_size(params, vocab, pad_vocab = pad_vocab)
+        check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
         of = OutputFile(fname_out, endianess=endianess)
 
@@ -1120,8 +1145,11 @@ class OutputFile:
 
         # meta data
         of.add_meta_arch(params)
-        of.add_meta_vocab(vocab)
-        of.add_meta_special_vocab(svocab)
+        if isinstance(vocab, NoVocab):
+            of.gguf.add_tokenizer_model(vocab.tokenizer_model)
+        else:
+            of.add_meta_vocab(vocab)
+            of.add_meta_special_vocab(svocab)
 
         # tensor info
         for name, lazy_tensor in model.items():
@@ -1131,24 +1159,7 @@ class OutputFile:
         of.write_tensor_info()
 
         # tensor data
-        ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
-        if ftype == GGMLFileType.MostlyQ8_0:
-            ndarrays = bounded_parallel_map(
-                OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency,
-                use_processpool_executor=True,
-            )
-        else:
-            ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
-
-        start = time.time()
-        for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
-            elapsed = time.time() - start
-            size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
-            padi = len(str(len(model)))
-            print(
-                f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}"
-            )
-            of.gguf.write_tensor_data(ndarray)
+        of.write_tensor_data(ftype, model, concurrency)
 
         of.close()
 
@@ -1309,8 +1320,8 @@ class VocabFactory:
                 return vtype, path
         raise FileNotFoundError(f"Could not find any of {[self._FILES[vt] for vt in vocab_types]}")
 
-    def _create_special_vocab(self, vocab: Vocab, vocabtype: str, model_parent_path: Path) -> gguf.SpecialVocab:
-        load_merges = vocabtype == "bpe"
+    def _create_special_vocab(self, vocab: Vocab, model_parent_path: Path) -> gguf.SpecialVocab:
+        load_merges = vocab.name == "bpe"
         n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None
         return gguf.SpecialVocab(
             model_parent_path,
@@ -1319,30 +1330,34 @@ class VocabFactory:
             n_vocab=n_vocab,
         )
 
-    def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]:
+    def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab:
         vocab_type, path = self._select_file(vocab_types)
         print(f"Loading vocab file {path!r}, type {vocab_type!r}")
 
         added_tokens_path = path.parent / "added_tokens.json"
-        vocab: Vocab
         if vocab_type == "bpe":
-            vocab = BpeVocab(
+            return BpeVocab(
                 path, added_tokens_path if added_tokens_path.exists() else None
             )
-        elif vocab_type == "spm":
-            vocab = SentencePieceVocab(
+        if vocab_type == "spm":
+            return SentencePieceVocab(
                 path, added_tokens_path if added_tokens_path.exists() else None
             )
-        elif vocab_type == "hfft":
-            vocab = HfVocab(
+        if vocab_type == "hfft":
+            return HfVocab(
                 path.parent, added_tokens_path if added_tokens_path.exists() else None
             )
+        raise ValueError(vocab_type)
+
+    def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]:
+        vocab: Vocab
+        if len(vocab_types) == 1 and "no_vocab" in vocab_types:
+            vocab = NoVocab()
         else:
-            raise ValueError(vocab_type)
+            vocab = self._create_vocab_by_path(vocab_types)
         # FIXME: Respect --vocab-dir?
         special_vocab = self._create_special_vocab(
             vocab,
-            vocab_type,
             model_parent_path,
         )
         return vocab, special_vocab
@@ -1380,6 +1395,7 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--dump",         action="store_true",    help="don't convert, just show what's in the model")
     parser.add_argument("--dump-single",  action="store_true",    help="don't convert, just show what's in a single model file")
     parser.add_argument("--vocab-only",   action="store_true",    help="extract only the vocab")
+    parser.add_argument("--no-vocab",     action="store_true",    help="store model without the vocab")
     parser.add_argument("--outtype",      choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
     parser.add_argument("--vocab-dir",    type=Path,              help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--vocab-type",                           help="vocab types to try in order, choose from 'spm', 'bpe', 'hfft' (default: spm,hfft)", default="spm,hfft")
@@ -1392,6 +1408,10 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--skip-unknown", action="store_true",    help="skip unknown tensor names instead of failing")
 
     args = parser.parse_args(args_in)
+    if args.no_vocab:
+        if args.vocab_only:
+            raise ValueError("no need to specify --vocab-only if using --no-vocab")
+        args.vocab_type = "no_vocab"
 
     if args.dump_single:
         model_plus = lazy_load_file(args.model)
@@ -1442,7 +1462,7 @@ def main(args_in: list[str] | None = None) -> None:
         print(f"Wrote {outfile}")
         return
 
-    if model_plus.vocab is not None and args.vocab_dir is None:
+    if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab:
         vocab = model_plus.vocab
 
     print(f"Vocab info: {vocab}")
index 99f71f0a1aa29b7d549be7ba63f904be3da9ffe4..2d7cf16c14ed1f86ba92b5441bef9de0c3c95448 100644 (file)
@@ -32,6 +32,7 @@ class Keys:
         FILE_TYPE            = "general.file_type"
 
     class LLM:
+        VOCAB_SIZE            = "{arch}.vocab_size"
         CONTEXT_LENGTH        = "{arch}.context_length"
         EMBEDDING_LENGTH      = "{arch}.embedding_length"
         BLOCK_COUNT           = "{arch}.block_count"
@@ -752,6 +753,7 @@ KEY_GENERAL_SOURCE_HF_REPO       = Keys.General.SOURCE_HF_REPO
 KEY_GENERAL_FILE_TYPE            = Keys.General.FILE_TYPE
 
 # LLM
+KEY_VOCAB_SIZE            = Keys.LLM.VOCAB_SIZE
 KEY_CONTEXT_LENGTH        = Keys.LLM.CONTEXT_LENGTH
 KEY_EMBEDDING_LENGTH      = Keys.LLM.EMBEDDING_LENGTH
 KEY_BLOCK_COUNT           = Keys.LLM.BLOCK_COUNT
index 4d389be951d721031b5357002d5dc4d72a6c761d..81b2eb884d4854cca98915fb94033324ba754429 100644 (file)
@@ -321,6 +321,9 @@ class GGUFWriter:
         self.data_alignment = alignment
         self.add_uint32(Keys.General.ALIGNMENT, alignment)
 
+    def add_vocab_size(self, size: int) -> None:
+        self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)
+
     def add_context_length(self, length: int) -> None:
         self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length)
 
index eb48b1e90c8d1cbdc3ed14384f82cf5c3300dce6..10fd53469eb6fc10480eb11f462981eb8524f451 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -258,6 +258,7 @@ enum llm_kv {
     LLM_KV_GENERAL_SOURCE_URL,
     LLM_KV_GENERAL_SOURCE_HF_REPO,
 
+    LLM_KV_VOCAB_SIZE,
     LLM_KV_CONTEXT_LENGTH,
     LLM_KV_EMBEDDING_LENGTH,
     LLM_KV_BLOCK_COUNT,
@@ -321,6 +322,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_GENERAL_SOURCE_URL,            "general.source.url"                    },
     { LLM_KV_GENERAL_SOURCE_HF_REPO,        "general.source.huggingface.repository" },
 
+    { LLM_KV_VOCAB_SIZE,                    "%s.vocab_size"            },
     { LLM_KV_CONTEXT_LENGTH,                "%s.context_length"        },
     { LLM_KV_EMBEDDING_LENGTH,              "%s.embedding_length"      },
     { LLM_KV_BLOCK_COUNT,                   "%s.block_count"           },
@@ -3242,10 +3244,11 @@ static const char * llama_model_type_name(e_model type) {
 
 static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
     switch (type) {
-        case LLAMA_VOCAB_TYPE_SPM: return "SPM";
-        case LLAMA_VOCAB_TYPE_BPE: return "BPE";
-        case LLAMA_VOCAB_TYPE_WPM: return "WPM";
-        default:                   return "unknown";
+        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
+        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
+        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
+        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
+        default:                    return "unknown";
     }
 }
 
@@ -3277,14 +3280,14 @@ static void llm_load_hparams(
     ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);
 
     // get hparams kv
-    ml.get_arr_n(LLM_KV_TOKENIZER_LIST,       hparams.n_vocab);
-    ml.get_key  (LLM_KV_CONTEXT_LENGTH,       hparams.n_ctx_train);
-    ml.get_key  (LLM_KV_EMBEDDING_LENGTH,     hparams.n_embd);
-    ml.get_key  (LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff);
-    ml.get_key  (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
-    ml.get_key  (LLM_KV_BLOCK_COUNT,          hparams.n_layer);
-    ml.get_key  (LLM_KV_EXPERT_COUNT,         hparams.n_expert,      false);
-    ml.get_key  (LLM_KV_EXPERT_USED_COUNT,    hparams.n_expert_used, false);
+    ml.get_key(LLM_KV_VOCAB_SIZE,           hparams.n_vocab,       false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
+    ml.get_key(LLM_KV_CONTEXT_LENGTH,       hparams.n_ctx_train);
+    ml.get_key(LLM_KV_EMBEDDING_LENGTH,     hparams.n_embd);
+    ml.get_key(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff);
+    ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
+    ml.get_key(LLM_KV_BLOCK_COUNT,          hparams.n_layer);
+    ml.get_key(LLM_KV_EXPERT_COUNT,         hparams.n_expert,      false);
+    ml.get_key(LLM_KV_EXPERT_USED_COUNT,    hparams.n_expert_used, false);
 
     GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
     GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
@@ -3645,30 +3648,25 @@ static void llm_load_vocab(
 
     const auto kv = LLM_KV(model.arch);
 
-    const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
-    if (token_idx == -1) {
-        throw std::runtime_error("cannot find tokenizer vocab in model file\n");
-    }
-
-    const float * scores = nullptr;
-    const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
-    if (score_idx != -1) {
-        scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
-    }
-
-    const int * toktypes = nullptr;
-    const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
-    if (toktype_idx != -1) {
-        toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
-    }
-
     // determine vocab type
     {
         std::string tokenizer_name;
 
         ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name);
 
-        if (tokenizer_name == "llama") {
+        if (tokenizer_name == "no_vocab") {
+            vocab.type = LLAMA_VOCAB_TYPE_NONE;
+
+            // default special tokens
+            vocab.special_bos_id = -1;
+            vocab.special_eos_id = -1;
+            vocab.special_unk_id = -1;
+            vocab.special_sep_id = -1;
+            vocab.special_pad_id = -1;
+            vocab.linefeed_id    = -1;
+
+            return;
+        } else if (tokenizer_name == "llama") {
             vocab.type = LLAMA_VOCAB_TYPE_SPM;
 
             // default special tokens
@@ -3734,6 +3732,23 @@ static void llm_load_vocab(
         }
     }
 
+    const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
+    if (token_idx == -1) {
+        throw std::runtime_error("cannot find tokenizer vocab in model file\n");
+    }
+
+    const float * scores = nullptr;
+    const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
+    if (score_idx != -1) {
+        scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
+    }
+
+    const int * toktypes = nullptr;
+    const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
+    if (toktype_idx != -1) {
+        toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
+    }
+
     const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
 
     vocab.id_to_token.resize(n_vocab);
@@ -5023,7 +5038,8 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
 
         llm_load_print_meta(ml, model);
 
-        if (model.hparams.n_vocab != model.vocab.id_to_token.size()) {
+        if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
+            model.hparams.n_vocab != model.vocab.id_to_token.size()) {
             throw std::runtime_error("vocab size mismatch");
         }
 
@@ -9361,26 +9377,32 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
 }
 
 static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
+    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
     return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL;
 }
 
 static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
+    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
     return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN;
 }
 
 static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
+    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
     return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL;
 }
 
 static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
+    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
     return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE;
 }
 
 static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) {
+    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
     return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED;
 }
 
 static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
+    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
     GGML_ASSERT(llama_is_byte_token(vocab, id));
     const auto& token_data = vocab.id_to_token.at(id);
     switch (llama_vocab_get_type(vocab)) {
@@ -9401,6 +9423,7 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
 }
 
 static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
+    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
     static const char * hex = "0123456789ABCDEF";
     switch (llama_vocab_get_type(vocab)) {
         case LLAMA_VOCAB_TYPE_SPM: {
@@ -10232,6 +10255,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     }
                 }
             } break;
+        case LLAMA_VOCAB_TYPE_NONE:
+            GGML_ASSERT(false);
     }
 
     return output;
@@ -13138,7 +13163,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
 }
 
 int32_t llama_n_vocab(const struct llama_model * model) {
-    return model->vocab.id_to_token.size();
+    return model->hparams.n_vocab;
 }
 
 int32_t llama_n_ctx_train(const struct llama_model * model) {
@@ -13962,14 +13987,17 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id
 }
 
 const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
+    GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
     return model->vocab.id_to_token[token].text.c_str();
 }
 
 float llama_token_get_score(const struct llama_model * model, llama_token token) {
+    GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
     return model->vocab.id_to_token[token].score;
 }
 
 llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) {
+    GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
     return model->vocab.id_to_token[token].type;
 }
 
diff --git a/llama.h b/llama.h
index 2d16cc9b9fa2c2b9c6847986f386bb8aceaa025f..90aa5372e740b264a76169202228ac58cc0029b8 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -59,9 +59,10 @@ extern "C" {
     typedef int32_t llama_seq_id;
 
     enum llama_vocab_type {
-        LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
-        LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
-        LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece
+        LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
+        LLAMA_VOCAB_TYPE_SPM  = 1, // SentencePiece
+        LLAMA_VOCAB_TYPE_BPE  = 2, // Byte Pair Encoding
+        LLAMA_VOCAB_TYPE_WPM  = 3, // WordPiece
     };
 
     // note: these values should be synchronized with ggml_rope