]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
gguf : add special tokens metadata for FIM/Infill (#6689)
authorDaniel Bevenius <redacted>
Tue, 16 Apr 2024 06:13:13 +0000 (08:13 +0200)
committerGitHub <redacted>
Tue, 16 Apr 2024 06:13:13 +0000 (09:13 +0300)
This commit adds special token metadata for Fill-In-the-Middle
(FIM)/Infill to the GGUF model.

The motivation for this is that currently there is support for CodeLlama
but other models exist now like CodeGemma, but the different models use
different token ids for the special tokens and this commit allows for
supporting multiple models.

Signed-off-by: Daniel Bevenius <redacted>
convert-hf-to-gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
llama.cpp

index b51d68307531610480b5ec1142da0f27777d8e60..6d28ab5e4919cb4f39c2449f619c55c2016af85b 100755 (executable)
@@ -1221,6 +1221,14 @@ class LlamaModel(Model):
         except FileNotFoundError:
             self._set_vocab_llama_hf()
 
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
+                                          special_token_types = ['prefix', 'suffix', 'middle', 'eot'])
+        special_vocab._set_special_token("prefix", 32007)
+        special_vocab._set_special_token("suffix", 32008)
+        special_vocab._set_special_token("middle", 32009)
+        special_vocab._set_special_token("eot",    32010)
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
@@ -2240,6 +2248,13 @@ class GemmaModel(Model):
 
     def set_vocab(self):
         self._set_vocab_sentencepiece()
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
+                                          special_token_types = ['prefix', 'suffix', 'middle', 'eot'])
+        special_vocab._set_special_token("prefix", 67)
+        special_vocab._set_special_token("suffix", 69)
+        special_vocab._set_special_token("middle", 68)
+        special_vocab._set_special_token("eot",    70)
+        special_vocab.add_to_gguf(self.gguf_writer)
 
     def set_gguf_parameters(self):
         hparams = self.hparams
index 2566b2fb8f2966084dba6189f9783ba3528349ac..1358206a3e5691e175c63709a322552f521eb2a9 100644 (file)
@@ -90,6 +90,11 @@ class Keys:
         HF_JSON          = "tokenizer.huggingface.json"
         RWKV             = "tokenizer.rwkv.world"
         CHAT_TEMPLATE    = "tokenizer.chat_template"
+        # FIM/Infill special tokens constants
+        PREFIX_ID        = "tokenizer.ggml.prefix_token_id"
+        SUFFIX_ID        = "tokenizer.ggml.suffix_token_id"
+        MIDDLE_ID        = "tokenizer.ggml.middle_token_id"
+        EOT_ID           = "tokenizer.ggml.eot_token_id"
 
 
 #
@@ -885,3 +890,7 @@ KEY_TOKENIZER_CLS_ID     = Keys.Tokenizer.CLS_ID
 KEY_TOKENIZER_MASK_ID    = Keys.Tokenizer.MASK_ID
 KEY_TOKENIZER_HF_JSON    = Keys.Tokenizer.HF_JSON
 KEY_TOKENIZER_RWKV       = Keys.Tokenizer.RWKV
+KEY_TOKENIZER_PRIFIX_ID  = Keys.Tokenizer.PREFIX_ID
+KEY_TOKENIZER_SUFFIX_ID  = Keys.Tokenizer.SUFFIX_ID
+KEY_TOKENIZER_MIDDLE_ID  = Keys.Tokenizer.MIDDLE_ID
+KEY_TOKENIZER_EOT_ID     = Keys.Tokenizer.EOT_ID
index f4c4407667bda04cace93408f19e017fe0ef28da..ff9326d59c717c8b9fe68daccf6fa3b33f2f9129 100644 (file)
@@ -469,6 +469,18 @@ class GGUFWriter:
     def add_chat_template(self, value: str) -> None:
         self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)
 
+    def add_prefix_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.PREFIX_ID, id)
+
+    def add_suffix_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.SUFFIX_ID, id)
+
+    def add_middle_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.MIDDLE_ID, id)
+
+    def add_eot_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.EOT_ID, id)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
index a5ef2fd8fa57597276af0dd20bc47b5e85730ba3..38e5936254e4a40c156818d28a6a827728a15b52 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -327,6 +327,10 @@ enum llm_kv {
     LLM_KV_TOKENIZER_ADD_PREFIX,
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
+    LLM_KV_TOKENIZER_PREFIX_ID,
+    LLM_KV_TOKENIZER_SUFFIX_ID,
+    LLM_KV_TOKENIZER_MIDDLE_ID,
+    LLM_KV_TOKENIZER_EOT_ID,
 };
 
 static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
@@ -399,6 +403,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_ADD_PREFIX,          "tokenizer.ggml.add_space_prefix"   },
     { LLM_KV_TOKENIZER_HF_JSON,             "tokenizer.huggingface.json"        },
     { LLM_KV_TOKENIZER_RWKV,                "tokenizer.rwkv.world"              },
+    { LLM_KV_TOKENIZER_PREFIX_ID,           "tokenizer.ggml.prefix_token_id"    },
+    { LLM_KV_TOKENIZER_SUFFIX_ID,           "tokenizer.ggml.suffix_token_id"    },
+    { LLM_KV_TOKENIZER_MIDDLE_ID,           "tokenizer.ggml.middle_token_id"    },
+    { LLM_KV_TOKENIZER_EOT_ID,              "tokenizer.ggml.eot_token_id"       },
 };
 
 struct LLM_KV {
@@ -2055,10 +2063,10 @@ struct llama_vocab {
     int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
 
     id linefeed_id       = 13;
-    id special_prefix_id = 32007;
-    id special_middle_id = 32009;
-    id special_suffix_id = 32008;
-    id special_eot_id    = 32010;
+    id special_prefix_id = -1;
+    id special_suffix_id = -1;
+    id special_middle_id = -1;
+    id special_eot_id    = -1;
 
     bool add_space_prefix = true;
 
@@ -4072,6 +4080,30 @@ static void llm_load_vocab(
             vocab.special_cls_id  = -1;
             vocab.special_mask_id = -1;
 
+            // For Fill-In-the-Middle (FIM)/infill models which where converted
+            // prior to support of FIM special tokens in GGUF, the following
+            // will allow those models to continue to work. The general names
+            // of the known models are currently CodeLlama (LLM_ARCH_LLAMA) and
+            // CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed once
+            // new versions of these models have been published.
+            std::string gen_name;
+            ml.get_key(LLM_KV_GENERAL_NAME, gen_name);
+            std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(),
+                [](unsigned char c){ return std::tolower(c); });
+            if (gen_name.find("code") != std::string::npos) {
+                if (model.arch == LLM_ARCH_LLAMA) {
+                    vocab.special_prefix_id = 32007;
+                    vocab.special_suffix_id = 32008;
+                    vocab.special_middle_id = 32009;
+                    vocab.special_eot_id    = 32010;
+                } else if (model.arch == LLM_ARCH_GEMMA) {
+                    vocab.special_prefix_id = 67;
+                    vocab.special_suffix_id = 69;
+                    vocab.special_middle_id = 68;
+                    vocab.special_eot_id    = 70;
+                }
+            }
+
             const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
             if (add_space_prefix_keyidx != -1) {
                 vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
@@ -4185,13 +4217,17 @@ static void llm_load_vocab(
     // special tokens
     {
         const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
-            { LLM_KV_TOKENIZER_BOS_ID,  vocab.special_bos_id  },
-            { LLM_KV_TOKENIZER_EOS_ID,  vocab.special_eos_id  },
-            { LLM_KV_TOKENIZER_UNK_ID,  vocab.special_unk_id  },
-            { LLM_KV_TOKENIZER_SEP_ID,  vocab.special_sep_id  },
-            { LLM_KV_TOKENIZER_PAD_ID,  vocab.special_pad_id  },
-            { LLM_KV_TOKENIZER_CLS_ID,  vocab.special_cls_id  },
-            { LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id },
+            { LLM_KV_TOKENIZER_BOS_ID,    vocab.special_bos_id    },
+            { LLM_KV_TOKENIZER_EOS_ID,    vocab.special_eos_id    },
+            { LLM_KV_TOKENIZER_UNK_ID,    vocab.special_unk_id    },
+            { LLM_KV_TOKENIZER_SEP_ID,    vocab.special_sep_id    },
+            { LLM_KV_TOKENIZER_PAD_ID,    vocab.special_pad_id    },
+            { LLM_KV_TOKENIZER_CLS_ID,    vocab.special_cls_id    },
+            { LLM_KV_TOKENIZER_MASK_ID,   vocab.special_mask_id   },
+            { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id },
+            { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
+            { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
+            { LLM_KV_TOKENIZER_EOT_ID,    vocab.special_eot_id    },
         };
         for (const auto & it : special_token_types) {
             const std::string & key = kv(std::get<0>(it));