]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : remove notion of CLS token (#11064)
authorGeorgi Gerganov <redacted>
Sun, 12 Jan 2025 10:15:53 +0000 (12:15 +0200)
committerGitHub <redacted>
Sun, 12 Jan 2025 10:15:53 +0000 (12:15 +0200)
ggml-ci

gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
include/llama.h
src/llama-vocab.cpp
src/llama-vocab.h

index 56aa9288dad40cbdb25fe8022381cff8022eb08a..8fe84df21ea2093dfdbd304e5abdecbd2d69b1f1 100644 (file)
@@ -184,7 +184,6 @@ class Keys:
         UNK_ID               = "tokenizer.ggml.unknown_token_id"
         SEP_ID               = "tokenizer.ggml.seperator_token_id"
         PAD_ID               = "tokenizer.ggml.padding_token_id"
-        CLS_ID               = "tokenizer.ggml.cls_token_id"
         MASK_ID              = "tokenizer.ggml.mask_token_id"
         ADD_BOS              = "tokenizer.ggml.add_bos_token"
         ADD_EOS              = "tokenizer.ggml.add_eos_token"
@@ -1837,7 +1836,6 @@ KEY_TOKENIZER_EOM_ID     = Keys.Tokenizer.EOM_ID
 KEY_TOKENIZER_UNK_ID     = Keys.Tokenizer.UNK_ID
 KEY_TOKENIZER_SEP_ID     = Keys.Tokenizer.SEP_ID
 KEY_TOKENIZER_PAD_ID     = Keys.Tokenizer.PAD_ID
-KEY_TOKENIZER_CLS_ID     = Keys.Tokenizer.CLS_ID
 KEY_TOKENIZER_MASK_ID    = Keys.Tokenizer.MASK_ID
 KEY_TOKENIZER_HF_JSON    = Keys.Tokenizer.HF_JSON
 KEY_TOKENIZER_RWKV       = Keys.Tokenizer.RWKV
index bf851c92ca5483d961810e09cb22fdee211c1d6d..080d2b9dce5cb36dc0ee64ce19f53e06690be8ad 100644 (file)
@@ -857,9 +857,6 @@ class GGUFWriter:
     def add_pad_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.PAD_ID, id)
 
-    def add_cls_token_id(self, id: int) -> None:
-        self.add_uint32(Keys.Tokenizer.CLS_ID, id)
-
     def add_mask_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.MASK_ID, id)
 
index 9f04bc62202ea574750935524105f5ef54d7dc84..a184884c77a51d67b3c824e922329b66130b3a93 100644 (file)
@@ -937,7 +937,6 @@ extern "C" {
     LLAMA_API llama_token llama_vocab_bos(const struct llama_vocab * vocab); // beginning-of-sentence
     LLAMA_API llama_token llama_vocab_eos(const struct llama_vocab * vocab); // end-of-sentence
     LLAMA_API llama_token llama_vocab_eot(const struct llama_vocab * vocab); // end-of-turn
-    LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab); // classification
     LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
     LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
     LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
@@ -973,6 +972,10 @@ extern "C" {
     DEPRECATED(LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab), "use llama_vocab_fim_rep instead");
     DEPRECATED(LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab), "use llama_vocab_fim_sep instead");
 
+    // CLS is equivalent to BOS
+    DEPRECATED(LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab), // classification
+            "use llama_vocab_bos instead");
+
     //
     // Tokenization
     //
index ed87517376e4d92891a867407403df77c81b2bab..d0fb85cea3127caccec2aefd63fdc91133d355b3 100644 (file)
@@ -1218,7 +1218,6 @@ struct llama_vocab::impl {
     llama_token special_unk_id  = 0;
     llama_token special_sep_id  = LLAMA_TOKEN_NULL;
     llama_token special_pad_id  = LLAMA_TOKEN_NULL;
-    llama_token special_cls_id  = LLAMA_TOKEN_NULL; // TODO: revisit if this is really needed https://github.com/ggerganov/llama.cpp/pull/10930
     llama_token special_mask_id = LLAMA_TOKEN_NULL;
 
     llama_token linefeed_id = 13;
@@ -1352,7 +1351,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             special_unk_id  = LLAMA_TOKEN_NULL;
             special_sep_id  = LLAMA_TOKEN_NULL;
             special_pad_id  = LLAMA_TOKEN_NULL;
-            special_cls_id  = LLAMA_TOKEN_NULL;
             special_mask_id = LLAMA_TOKEN_NULL;
             linefeed_id     = LLAMA_TOKEN_NULL;
 
@@ -1374,18 +1372,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             special_unk_id  = 0;
             special_sep_id  = LLAMA_TOKEN_NULL;
             special_pad_id  = LLAMA_TOKEN_NULL;
-            special_cls_id  = LLAMA_TOKEN_NULL;
             special_mask_id = LLAMA_TOKEN_NULL;
         } else if (tokenizer_model == "bert") {
             type = LLAMA_VOCAB_TYPE_WPM;
 
             // default special tokens
-            special_bos_id  = LLAMA_TOKEN_NULL;
+            special_bos_id  = 101;
             special_eos_id  = LLAMA_TOKEN_NULL;
             special_unk_id  = 100;
             special_sep_id  = 102;
             special_pad_id  = 0;
-            special_cls_id  = 101;
             special_mask_id = 103;
         } else if (tokenizer_model == "gpt2") {
             type = LLAMA_VOCAB_TYPE_BPE;
@@ -1420,7 +1416,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             special_unk_id  = LLAMA_TOKEN_NULL;
             special_sep_id  = LLAMA_TOKEN_NULL;
             special_pad_id  = LLAMA_TOKEN_NULL;
-            special_cls_id  = LLAMA_TOKEN_NULL;
             special_mask_id = LLAMA_TOKEN_NULL;
         } else if (tokenizer_model == "t5") {
             type = LLAMA_VOCAB_TYPE_UGM;
@@ -1431,7 +1426,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             special_unk_id  = 2;
             special_sep_id  = LLAMA_TOKEN_NULL;
             special_pad_id  = 0;
-            special_cls_id  = LLAMA_TOKEN_NULL;
             special_mask_id = LLAMA_TOKEN_NULL;
 
             const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
@@ -1712,7 +1706,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             { LLM_KV_TOKENIZER_UNK_ID,     special_unk_id     },
             { LLM_KV_TOKENIZER_SEP_ID,     special_sep_id     },
             { LLM_KV_TOKENIZER_PAD_ID,     special_pad_id     },
-            { LLM_KV_TOKENIZER_CLS_ID,     special_cls_id     },
             { LLM_KV_TOKENIZER_MASK_ID,    special_mask_id    },
             { LLM_KV_TOKENIZER_FIM_PRE_ID, special_fim_pre_id },
             { LLM_KV_TOKENIZER_FIM_SUF_ID, special_fim_suf_id },
@@ -2406,8 +2399,8 @@ std::vector<llama_token> llama_vocab::impl::tokenize(
         case LLAMA_VOCAB_TYPE_WPM:
             {
                 if (add_special) {
-                    GGML_ASSERT(special_cls_id != LLAMA_TOKEN_NULL);
-                    output.push_back(special_cls_id);
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
                 }
 
                 llm_tokenizer_wpm_session session(vocab);
@@ -2700,7 +2693,6 @@ void llama_vocab::impl::print_info() const {
     if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, special_unk_id,     id_to_token[special_unk_id].text.c_str() );  }
     if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, special_sep_id,     id_to_token[special_sep_id].text.c_str() );  }
     if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, special_pad_id,     id_to_token[special_pad_id].text.c_str() );  }
-    if (special_cls_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, special_cls_id,     id_to_token[special_cls_id].text.c_str() );  }
     if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, special_mask_id,    id_to_token[special_mask_id].text.c_str() ); }
 
     if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, linefeed_id,        id_to_token[linefeed_id].text.c_str() ); }
@@ -2834,7 +2826,7 @@ llama_token_attr llama_vocab::token_get_attr(llama_token id) const {
 }
 
 llama_token llama_vocab::token_bos() const {
-    return pimpl->type != LLAMA_VOCAB_TYPE_WPM ? pimpl->special_bos_id : pimpl->special_cls_id;
+    return pimpl->special_bos_id;
 }
 
 llama_token llama_vocab::token_eos() const {
@@ -2853,10 +2845,6 @@ llama_token llama_vocab::token_unk() const {
     return pimpl->special_unk_id;
 }
 
-llama_token llama_vocab::token_cls() const {
-    return pimpl->special_cls_id;
-}
-
 llama_token llama_vocab::token_sep() const {
     return pimpl->special_sep_id;
 }
@@ -3069,8 +3057,9 @@ llama_token llama_vocab_eot(const struct llama_vocab * vocab) {
     return vocab->token_eot();
 }
 
+// deprecated
 llama_token llama_vocab_cls(const struct llama_vocab * vocab) {
-    return vocab->token_cls();
+    return vocab->token_bos();
 }
 
 llama_token llama_vocab_sep(const struct llama_vocab * vocab) {
@@ -3159,7 +3148,8 @@ llama_token llama_token_eot(const struct llama_vocab * vocab) {
 
 // deprecated
 llama_token llama_token_cls(const struct llama_vocab * vocab) {
-    return llama_vocab_cls(vocab);
+    //return llama_vocab_cls(vocab);
+    return llama_vocab_bos(vocab); // avoid deprecation warning
 }
 
 // deprecated
index 020f2b5337567460d0c10ed58ff9ab5ef6fb164c..5ce355214346f12313c40f689ebfe3aa4b8d1a12 100644 (file)
@@ -53,7 +53,6 @@ struct llama_vocab {
     llama_token token_eot() const;
     llama_token token_eom() const;
     llama_token token_unk() const;
-    llama_token token_cls() const;
     llama_token token_sep() const;
     llama_token token_nl () const;
     llama_token token_pad() const;