]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : keep track of all EOG tokens in the vocab (#9609)
authorGeorgi Gerganov <redacted>
Tue, 24 Sep 2024 07:16:06 +0000 (10:16 +0300)
committerGitHub <redacted>
Tue, 24 Sep 2024 07:16:06 +0000 (10:16 +0300)
ggml-ci

src/llama-vocab.cpp
src/llama-vocab.h
src/llama.cpp

index 2c007477e8da24341539bb8df5f2751ae112b624..a771eccda30172a3f2a9b23e4e96d4472524534b 100644 (file)
@@ -1570,11 +1570,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla
 }
 
 bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
-    return token != -1 && (
-        token == llama_token_eos_impl(vocab) ||
-        token == llama_token_eot_impl(vocab) ||
-        token == llama_token_eom_impl(vocab)
-    );
+    return token != -1 && vocab.special_eog_ids.count(token) > 0;
 }
 
 bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
index dc4b5f12f7860030c2fc2a4f25e974597c0075ba..cc46f642bf1ae371ce5fa2b23aa9ed3b44f5895d 100644 (file)
@@ -6,6 +6,7 @@
 #include <vector>
 #include <unordered_map>
 #include <map>
+#include <set>
 
 struct llama_vocab {
     using id    = llama_token;
@@ -49,12 +50,15 @@ struct llama_vocab {
     id special_eot_id    = -1; // TODO: move above after "eos_id", and here add "file separator" token
     id special_eom_id    = -1;
 
+    // set of all tokens that cause "end of generation"
+    std::set<id> special_eog_ids;
+
     // tokenizer flags
-    bool tokenizer_add_space_prefix = false;
-    bool tokenizer_add_bos          = false;
-    bool tokenizer_add_eos          = false;
-    bool tokenizer_ignore_merges    = false;
-    bool tokenizer_clean_spaces     = false;  // clean_up_tokenization_spaces
+    bool tokenizer_add_space_prefix           = false;
+    bool tokenizer_add_bos                    = false;
+    bool tokenizer_add_eos                    = false;
+    bool tokenizer_ignore_merges              = false;
+    bool tokenizer_clean_spaces               = false;  // clean_up_tokenization_spaces
     bool tokenizer_remove_extra_whitespaces   = false;
     bool tokenizer_escape_whitespaces         = true;
     bool tokenizer_treat_whitespace_as_suffix = false;
index c1ba2b3012127fdd9b192a57f8caab3e28924891..a718de054f934697aca987145cf21e26d29199e2 100644 (file)
@@ -6509,21 +6509,21 @@ static void llm_load_vocab(
         //       for now, we apply this workaround to find the EOT token based on its text
         if (vocab.special_eot_id == -1) {
             for (const auto & t : vocab.token_to_id) {
-                if (
+                if (false
                         // TODO: gemma "<end_of_turn>" is exported as a normal token, so the following check does not work
                         //       need to fix convert script
                         //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
-                        (t.first == "<|eot_id|>" ||
-                         t.first == "<|im_end|>" ||
-                         t.first == "<|end|>" ||
-                         t.first == "<end_of_turn>" ||
-                         t.first == "<|endoftext|>"
-                        )
+                        || t.first == "<|eot_id|>"
+                        || t.first == "<|im_end|>"
+                        || t.first == "<|end|>"
+                        || t.first == "<end_of_turn>"
+                        || t.first == "<|endoftext|>"
+                        || t.first == "<EOT>"
                    ) {
                     vocab.special_eot_id = t.second;
                     if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.first.c_str());
+                                __func__, t.first.c_str());
                         vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                     }
                     break;
@@ -6546,6 +6546,44 @@ static void llm_load_vocab(
                 }
             }
         }
+
+        // maintain a list of tokens that cause end-of-generation
+        // this is currently determined based on the token text, which is obviously not ideal
+        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        vocab.special_eog_ids.clear();
+        for (const auto & t : vocab.token_to_id) {
+            if (false
+                    || t.first == "<|eot_id|>"
+                    || t.first == "<|im_end|>"
+                    || t.first == "<|end|>"
+                    || t.first == "<end_of_turn>"
+                    || t.first == "<|endoftext|>"
+                    || t.first == "<|eom_id|>"
+                    || t.first == "<EOT>"
+               ) {
+                vocab.special_eog_ids.insert(t.second);
+                if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.first.c_str());
+                    vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                }
+            }
+        }
+
+        if (vocab.special_eos_id != -1 && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eos_id);
+            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (vocab.special_eot_id != -1 && vocab.special_eog_ids.count(vocab.special_eot_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eot_id);
+            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (vocab.special_eom_id != -1 && vocab.special_eog_ids.count(vocab.special_eom_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eom_id);
+            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
     }
 
     // build special tokens cache
@@ -6749,6 +6787,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token        = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
     if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token        = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
     if (vocab.special_eot_id    != -1) { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,    vocab.id_to_token[vocab.special_eot_id].text.c_str() );    }
+    if (vocab.special_eom_id    != -1) { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, vocab.special_eom_id,    vocab.id_to_token[vocab.special_eom_id].text.c_str() );    }
+
+    for (const auto & id : vocab.special_eog_ids) {
+        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, vocab.id_to_token[id].text.c_str() );
+    }
 
     LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);