]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : support Llama 3 HF conversion (#6745)
authorPedro Cuenca <redacted>
Sun, 21 Apr 2024 11:50:41 +0000 (13:50 +0200)
committerGitHub <redacted>
Sun, 21 Apr 2024 11:50:41 +0000 (14:50 +0300)
* Support Llama 3 conversion

The tokenizer is BPE.

* style

* Accept suggestion

Co-authored-by: Sourab Mangrulkar <redacted>
* llama : add llama_token_is_eog()

ggml-ci

* llama : auto-detect more EOT tokens when missing in KV data

* convert : replacing EOS token is a hack

* llama : fix codegemma EOT token + add TODOs

* llama : fix model type string for 8B model

---------

Co-authored-by: Sourab Mangrulkar <redacted>
Co-authored-by: Georgi Gerganov <redacted>
20 files changed:
convert-hf-to-gguf.py
convert.py
examples/batched.swift/Sources/main.swift
examples/batched/batched.cpp
examples/beam-search/beam-search.cpp
examples/infill/infill.cpp
examples/llama.android/app/src/main/cpp/llama-android.cpp
examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
examples/llava/llava-cli.cpp
examples/lookahead/lookahead.cpp
examples/lookup/lookup.cpp
examples/main/main.cpp
examples/parallel/parallel.cpp
examples/passkey/passkey.cpp
examples/server/server.cpp
examples/server/utils.hpp
examples/simple/simple.cpp
examples/speculative/speculative.cpp
llama.cpp
llama.h

index 358dba8ed9d9010340263f273dac7e6c77d62218..4fd916cba3ed73cb0fbfbedd1bf8bc8e4b952d6b 100755 (executable)
@@ -1301,15 +1301,23 @@ class LlamaModel(Model):
         try:
             self. _set_vocab_sentencepiece()
         except FileNotFoundError:
-            self._set_vocab_llama_hf()
-
-        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
-                                          special_token_types = ['prefix', 'suffix', 'middle', 'eot'])
-        special_vocab._set_special_token("prefix", 32007)
-        special_vocab._set_special_token("suffix", 32008)
-        special_vocab._set_special_token("middle", 32009)
-        special_vocab._set_special_token("eot",    32010)
-        special_vocab.add_to_gguf(self.gguf_writer)
+            try:
+                self._set_vocab_llama_hf()
+            except (FileNotFoundError, TypeError):
+                # Llama 3
+                self._set_vocab_gpt2()
+
+        # Apply to CodeLlama only (and ignore for Llama 3 with a vocab size of 128256)
+        if self.hparams.get("vocab_size", 32000) == 32016:
+            special_vocab = gguf.SpecialVocab(
+                self.dir_model, load_merges=False,
+                special_token_types = ['prefix', 'suffix', 'middle', 'eot']
+            )
+            special_vocab._set_special_token("prefix", 32007)
+            special_vocab._set_special_token("suffix", 32008)
+            special_vocab._set_special_token("middle", 32009)
+            special_vocab._set_special_token("eot",    32010)
+            special_vocab.add_to_gguf(self.gguf_writer)
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -2194,6 +2202,8 @@ class InternLM2Model(Model):
         old_eos = special_vocab.special_token_ids["eos"]
         if "chat" in os.path.basename(self.dir_model.absolute()):
             # For the chat model, we replace the eos with '<|im_end|>'.
+            # TODO: this is a hack, should be fixed
+            #       https://github.com/ggerganov/llama.cpp/pull/6745#issuecomment-2067687048
             special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer)
             print(f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} \
 in chat mode so that the conversation can end normally.")
@@ -2429,12 +2439,15 @@ class GemmaModel(Model):
 
     def set_vocab(self):
         self._set_vocab_sentencepiece()
+
+        # TODO: these special tokens should be exported only for the CodeGemma family
         special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
-                                          special_token_types = ['prefix', 'suffix', 'middle', 'eot'])
+                                          special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot'])
         special_vocab._set_special_token("prefix", 67)
         special_vocab._set_special_token("suffix", 69)
         special_vocab._set_special_token("middle", 68)
-        special_vocab._set_special_token("eot",    70)
+        special_vocab._set_special_token("fsep",   70)
+        special_vocab._set_special_token("eot",    107)
         special_vocab.add_to_gguf(self.gguf_writer)
 
     def set_gguf_parameters(self):
@@ -2523,28 +2536,34 @@ class MambaModel(Model):
 
             field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
             self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]))
+
             field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST)
             self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
+
             field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
             self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
+
             field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES)
             self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data])
+
             field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)
             self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0])
+
             field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)
             self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0])
+
             field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)
             self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0])
 
     def set_gguf_parameters(self):
-        d_model = self.find_hparam(["hidden_size", "d_model"])
-        d_conv  = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_model = self.find_hparam(["hidden_size",       "d_model"])
+        d_conv  = self.find_hparam(["conv_kernel",       "d_conv"],  optional=True) or 4
         d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
-        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16
+        d_state = self.find_hparam(["state_size",        "d_state"], optional=True) or 16
         # ceiling division
         # ref: https://stackoverflow.com/a/17511341/22827863
         # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
-        dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
+        dt_rank      = self.find_hparam(["time_step_rank",     "dt_rank"],      optional=True) or -(d_model // -16)
         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
 
         # Fail early for models which don't have a block expansion factor of 2
index 24df0a4d84dc35b32c580addbc4e3a8ab11e4069..1c700cf6a3d65e578109df32acb340d11d081dd1 100755 (executable)
@@ -525,7 +525,14 @@ class LlamaHfVocab(Vocab):
 
         # pre-check so we know if we need transformers
         tokenizer_model: dict[str, Any] = tokenizer_json['model']
-        if (
+        is_llama3 = (
+            tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False)
+            and not tokenizer_model.get('byte_fallback', True)
+        )
+        if is_llama3:
+            raise TypeError('Llama 3 must be converted with BpeVocab')
+
+        if not is_llama3 and (
             tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
             or tokenizer_json['decoder']['type'] != 'Sequence'
         ):
index d75c503d58311ff66f1f1ee34a6bfcafe8c76dfe..5764acb6d5825c06a8f7e5bb9416aa34d0030783 100644 (file)
@@ -153,7 +153,7 @@ while n_cur <= n_len {
         // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
 
         // is it an end of stream? -> mark the stream as finished
-        if new_token_id == llama_token_eos(model) || n_cur == n_len {
+        if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
             i_batch[i] = -1
             // print("")
             if n_parallel > 1 {
index 7aaf63ceb1a7c7c7e36a43bcb18c4de98f17b661..be30d20bf81947c2d3d53c5a9abbf578c4d9e68e 100644 (file)
@@ -191,8 +191,8 @@ int main(int argc, char ** argv) {
 
             //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
 
-            // is it an end of stream? -> mark the stream as finished
-            if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
+            // is it an end of generation? -> mark the stream as finished
+            if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
                 i_batch[i] = -1;
                 LOG_TEE("\n");
                 if (n_parallel > 1) {
index 866c6d7a628674303789bb5f64f521c026072bbe..3d34378a506eba50a2ff9eebe0cbdc13766b1131 100644 (file)
@@ -47,7 +47,7 @@ struct beam_search_callback_data {
 // In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
 // For example, eob can be flagged due to maximum token length, stop words, etc.
 static bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, size_t n_tokens) {
-    return n_tokens && tokens[n_tokens-1] == llama_token_eos(llama_get_model(callback_data.ctx));
+    return n_tokens && llama_token_is_eog(llama_get_model(callback_data.ctx), tokens[n_tokens-1]);
 }
 
 // Function matching type llama_beam_search_callback_fn_t.
index c69dcd06e461fee9b1b59ffcf093ccfd18babdf4..afac145f63934c72a7b0dec240db859da9d1df79 100644 (file)
@@ -586,7 +586,7 @@ int main(int argc, char ** argv) {
 
             // deal with eot token in infill mode
             if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
-                if(is_interacting && !params.interactive_first) {
+                if (is_interacting && !params.interactive_first) {
                     // print an eot token
                     printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
                 }
@@ -651,8 +651,8 @@ int main(int argc, char ** argv) {
                 // LOG_TEE("took new input\n");
                 is_interacting = false;
             }
-            // deal with end of text token in interactive mode
-            else if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) {
+            // deal with end of generation tokens in interactive mode
+            else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
@@ -731,8 +731,8 @@ int main(int argc, char ** argv) {
             }
         }
 
-        // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos(model) && !params.interactive) {
+        // end of generation
+        if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !params.interactive) {
             break;
         }
 
index ce8ab3b7094073542cabc0471f286751f9117df6..4af9de30383596dc0e7c804206b9349158c70a51 100644 (file)
@@ -408,7 +408,7 @@ Java_com_example_llama_Llm_completion_1loop(
     const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
 
     const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
-    if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
+    if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
         return env->NewStringUTF("");
     }
 
index c249291aea1107183889fba42b5fd31e9b8ff1a1..70c43a3852732e90adc8cab6520932c0bbc91053 100644 (file)
@@ -158,7 +158,7 @@ actor LlamaContext {
             new_token_id = llama_sample_token_greedy(context, &candidates_p)
         }
 
-        if new_token_id == llama_token_eos(model) || n_cur == n_len {
+        if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
             print("\n")
             let new_token_str = String(cString: temporary_invalid_cchars + [0])
             temporary_invalid_cchars.removeAll()
index 75948806ee5d4c878118f399a759b07b9f20d00c..50dac4caecfe21bc83d25c1709d96c689c037d1d 100644 (file)
@@ -45,7 +45,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
     const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
     llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
     static std::string ret;
-    if (id == llama_token_eos(llama_get_model(ctx_llama))) {
+    if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
         ret = "</s>";
     } else {
         ret = llama_token_to_piece(ctx_llama, id);
index 5af6a8ab6c92bc9422350dd3b443b206c72594ee..9c3540b2008c207678050f16013f691099cc87df 100644 (file)
@@ -299,7 +299,7 @@ int main(int argc, char ** argv) {
                 }
                 fflush(stdout);
 
-                if (id == llama_token_eos(model)) {
+                if (llama_token_is_eog(model, id)) {
                     has_eos = true;
                 }
 
index 65ed408a2758394f03508113f97319cd01d84ed0..9526e898fe7638218ea7577203bf65dec3133b43 100644 (file)
@@ -141,7 +141,7 @@ int main(int argc, char ** argv){
                 printf("%s", token_str.c_str());
             }
 
-            if (id == llama_token_eos(model)) {
+            if (llama_token_is_eog(model, id)) {
                 has_eos = true;
             }
 
index 249fc2bb605b36fc5f9184a7ce4475d4f8740ca7..1180734b9760d2cf4021fdba8fbc98bcf8a63ca3 100644 (file)
@@ -795,8 +795,8 @@ int main(int argc, char ** argv) {
                 }
             }
 
-            // deal with end of text token in interactive mode
-            if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) {
+            // deal with end of generation tokens in interactive mode
+            if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
@@ -920,8 +920,8 @@ int main(int argc, char ** argv) {
             }
         }
 
-        // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive || params.chatml)) {
+        // end of generation
+        if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !(params.instruct || params.interactive || params.chatml)) {
             LOG_TEE(" [end of text]\n");
             break;
         }
index f66c91013eaeba1c0d8881e898953bce9399ad68..7c5595d6edb2dc2077a17a1bc7ee6f7241d08fcc 100644 (file)
@@ -359,7 +359,7 @@ int main(int argc, char ** argv) {
                 //        client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
 
                 if (client.n_decoded > 2 &&
-                        (id == llama_token_eos(model) ||
+                        (llama_token_is_eog(model, id) ||
                          (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) ||
                          client.response.find("User:") != std::string::npos ||
                          client.response.find('\n') != std::string::npos)) {
index 2cbc9e1fa89ed1688d8e3e199e957282373bc0cb..f2ef9ca10d4a271f3357a3bc34843638f67aefae 100644 (file)
@@ -252,8 +252,8 @@ int main(int argc, char ** argv) {
             // sample the most likely token
             const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
 
-            // is it an end of stream?
-            if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
+            // is it an end of generation?
+            if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
                 LOG_TEE("\n");
 
                 break;
index 634e653ada2848b6af259da0f757432e44ea58f1..25bc29639677251df04869c72c1c27661e565c26 100644 (file)
@@ -1201,7 +1201,7 @@ struct server_context {
             });
         }
 
-        if (result.tok == llama_token_eos(model)) {
+        if (llama_token_is_eog(model, result.tok)) {
             slot.stopped_eos    = true;
             slot.has_next_token = false;
 
index a8d43ac63bf1142314ef0d3212af8f1276d0024a..1a22125028204816625da9f81a8bc5159f0dceba 100644 (file)
@@ -381,10 +381,6 @@ static json oaicompat_completion_params_parse(
     } else {
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
-    // Some chat templates don't use EOS token to stop generation
-    // We must add their end sequences to list of stop words
-    llama_params["stop"].push_back("<|im_end|>"); // chatml
-    llama_params["stop"].push_back("<end_of_turn>"); // gemma
 
     // Handle "response_format" field
     if (body.contains("response_format")) {
index 39e2d8ea490e3ed70e75dd9bebcd7a98a9d125fb..b0f8e0fdc49873b76f0b4f73423d765b59a3e8db 100644 (file)
@@ -133,8 +133,8 @@ int main(int argc, char ** argv) {
             // sample the most likely token
             const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
 
-            // is it an end of stream?
-            if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
+            // is it an end of generation?
+            if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
                 LOG_TEE("\n");
 
                 break;
index 6a7367b0cde6b0b699fd9a807acf12ff09c84ecc..12e46fbc91a242677457f99943d528b476e801a6 100644 (file)
@@ -360,7 +360,7 @@ int main(int argc, char ** argv) {
                     }
                 }
 
-                if (token_id == llama_token_eos(model_tgt)) {
+                if (llama_token_is_eog(model_tgt, token_id)) {
                     has_eos = true;
                 }
                 ++n_predict;
index fa7c022f291304a2c1c817700c361c37d90fbcf6..8ca9650de09ada7f890b39dfb26abda16e35570a 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2120,7 +2120,7 @@ struct llama_vocab {
     id special_prefix_id = -1;
     id special_suffix_id = -1;
     id special_middle_id = -1;
-    id special_eot_id    = -1;
+    id special_eot_id    = -1; // TODO: move above after "eos_id", and here add "file separator" token
 
     bool add_space_prefix = true;
 
@@ -3770,7 +3770,7 @@ static void llm_load_hparams(
                     switch (hparams.n_layer) {
                         case 22: model.type = e_model::MODEL_1B; break;
                         case 26: model.type = e_model::MODEL_3B; break;
-                        case 32: model.type = e_model::MODEL_7B; break;
+                        case 32: model.type = hparams.n_head == hparams.n_head_kv ? e_model::MODEL_7B : e_model::MODEL_8B; break; // LLaMa 8B v3 uses GQA
                         case 40: model.type = e_model::MODEL_13B; break;
                         case 48: model.type = e_model::MODEL_34B; break;
                         case 60: model.type = e_model::MODEL_30B; break;
@@ -4179,7 +4179,10 @@ static void llm_load_vocab(
                     vocab.special_prefix_id = 67;
                     vocab.special_suffix_id = 69;
                     vocab.special_middle_id = 68;
-                    vocab.special_eot_id    = 70;
+                    // TODO: this is not EOT, it is "file separator" token, needs fix
+                    //       https://huggingface.co/google/codegemma-7b-it/blob/9b1d9231388358c04d90bd003458f5070d97db44/tokenizer_config.json#L565-L572
+                    //vocab.special_eot_id    = 70;
+                    vocab.special_eot_id    = 107;
                 }
             }
 
@@ -4308,6 +4311,7 @@ static void llm_load_vocab(
             { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
             { LLM_KV_TOKENIZER_EOT_ID,    vocab.special_eot_id    },
         };
+
         for (const auto & it : special_token_types) {
             const std::string & key = kv(std::get<0>(it));
             int32_t & id = std::get<1>(it);
@@ -4322,7 +4326,6 @@ static void llm_load_vocab(
             } else {
                 id = new_id;
             }
-
         }
 
         // Handle add_bos_token and add_eos_token
@@ -4336,6 +4339,27 @@ static void llm_load_vocab(
                 vocab.special_add_eos = int(temp);
             }
         }
+
+        // find EOT token: "<|eot_id|>", "<|im_emd|>", "<end_of_turn>", etc.
+        //
+        // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOT_ID
+        //       for now, we apply this workaround to find the EOT token based on its text
+        if (vocab.special_eot_id == -1) {
+            for (const auto & t : vocab.token_to_id) {
+                if (
+                        // TODO: gemma "<end_of_turn>" is exported as a normal token, so the following check does not work
+                        //       need to fix convert script
+                        //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
+                        (t.first == "<|eot_id|>" ||
+                         t.first == "<|im_emd|>" ||
+                         t.first == "<end_of_turn>"
+                        )
+                   ) {
+                    vocab.special_eot_id = t.second;
+                    break;
+                }
+            }
+        }
     }
 
     // build special tokens cache
@@ -4498,14 +4522,19 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: general.name     = %s\n",    __func__, model.name.c_str());
 
     // special tokens
-    if (vocab.special_bos_id  != -1) { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id,  vocab.id_to_token[vocab.special_bos_id].text.c_str() );  }
-    if (vocab.special_eos_id  != -1) { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id,  vocab.id_to_token[vocab.special_eos_id].text.c_str() );  }
-    if (vocab.special_unk_id  != -1) { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id,  vocab.id_to_token[vocab.special_unk_id].text.c_str() );  }
-    if (vocab.special_sep_id  != -1) { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id,  vocab.id_to_token[vocab.special_sep_id].text.c_str() );  }
-    if (vocab.special_pad_id  != -1) { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id,  vocab.id_to_token[vocab.special_pad_id].text.c_str() );  }
-    if (vocab.special_cls_id  != -1) { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, vocab.special_cls_id,  vocab.id_to_token[vocab.special_cls_id].text.c_str() );  }
-    if (vocab.special_mask_id != -1) { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
-    if (vocab.linefeed_id     != -1) { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,     vocab.id_to_token[vocab.linefeed_id].text.c_str() );     }
+    if (vocab.special_bos_id    != -1) { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id,  vocab.id_to_token[vocab.special_bos_id].text.c_str() );  }
+    if (vocab.special_eos_id    != -1) { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id,  vocab.id_to_token[vocab.special_eos_id].text.c_str() );  }
+    if (vocab.special_unk_id    != -1) { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id,  vocab.id_to_token[vocab.special_unk_id].text.c_str() );  }
+    if (vocab.special_sep_id    != -1) { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id,  vocab.id_to_token[vocab.special_sep_id].text.c_str() );  }
+    if (vocab.special_pad_id    != -1) { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id,  vocab.id_to_token[vocab.special_pad_id].text.c_str() );  }
+    if (vocab.special_cls_id    != -1) { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, vocab.special_cls_id,  vocab.id_to_token[vocab.special_cls_id].text.c_str() );  }
+    if (vocab.special_mask_id   != -1) { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
+
+    if (vocab.linefeed_id       != -1) { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,       vocab.id_to_token[vocab.linefeed_id].text.c_str() );       }
+    if (vocab.special_prefix_id != -1) { LLAMA_LOG_INFO( "%s: PRE token        = %d '%s'\n", __func__, vocab.special_prefix_id, vocab.id_to_token[vocab.special_prefix_id].text.c_str() ); }
+    if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token        = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
+    if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token        = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
+    if (vocab.special_eot_id    != -1) { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,    vocab.id_to_token[vocab.special_eot_id].text.c_str() );    }
 }
 
 // Returns false if cancelled by progress_callback
@@ -13268,16 +13297,14 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
     GGML_ASSERT(ctx);
     const int64_t t_start_sample_us = ggml_time_us();
 
-    bool allow_eos = false;
+    bool allow_eog = false;
     for (const auto & stack : grammar->stacks) {
         if (stack.empty()) {
-            allow_eos = true;
+            allow_eog = true;
             break;
         }
     }
 
-    const llama_token eos = llama_token_eos(&ctx->model);
-
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
     candidates_decoded.reserve(candidates->size);
     std::vector<llama_grammar_candidate>                              candidates_grammar;
@@ -13286,8 +13313,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id    = candidates->data[i].id;
         const std::string piece = llama_token_to_piece(ctx, id);
-        if (id == eos) {
-            if (!allow_eos) {
+        if (llama_token_is_eog(&ctx->model, id)) {
+            if (!allow_eog) {
                 candidates->data[i].logit = -INFINITY;
             }
         } else if (piece.empty() || piece[0] == 0) {
@@ -13476,7 +13503,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
 void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
     const int64_t t_start_sample_us = ggml_time_us();
 
-    if (token == llama_token_eos(&ctx->model)) {
+    if (llama_token_is_eog(&ctx->model, token)) {
         for (const auto & stack : grammar->stacks) {
             if (stack.empty()) {
                 return;
@@ -16880,6 +16907,13 @@ llama_token_type llama_token_get_type(const struct llama_model * model, llama_to
     return model->vocab.id_to_token[token].type;
 }
 
+bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
+    return token != -1 && (
+        token == llama_token_eos(model) ||
+        token == llama_token_eot(model)
+    );
+}
+
 llama_token llama_token_bos(const struct llama_model * model) {
     return model->vocab.special_bos_id;
 }
diff --git a/llama.h b/llama.h
index b5da686f7b7e5af5f88b6a2066064fec7276e91c..5bed97ad1ef9f70efda2043dc5b6948209b1f55d 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -783,6 +783,9 @@ extern "C" {
 
     LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
 
+    // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
+    LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
+
     // Special tokens
     LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
     LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
@@ -796,7 +799,7 @@ extern "C" {
     // Returns -1 if unknown, 1 for true or 0 for false.
     LLAMA_API int32_t         llama_add_eos_token(const struct llama_model * model);
 
-    // codellama infill tokens
+    // Codellama infill tokens
     LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
     LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
     LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix