]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix detokenization of non-special added-tokens (#4916)
authorGeorgi Gerganov <redacted>
Sat, 13 Jan 2024 16:47:38 +0000 (18:47 +0200)
committerGitHub <redacted>
Sat, 13 Jan 2024 16:47:38 +0000 (18:47 +0200)
Co-authored-by: goerch <redacted>
llama.cpp

index 2754560884e5f662cc99ad024b0ef55dafe69dee..2190ea7aa92c2325792aae35886e317ec111fa75 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -10305,6 +10305,8 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
     if (0 <= token && token < llama_n_vocab(model)) {
         switch (llama_vocab_get_type(model->vocab)) {
         case LLAMA_VOCAB_TYPE_SPM: {
+            // NOTE: we accept all unsupported token types,
+            // suppressing them like CONTROL tokens.
             if (llama_is_normal_token(model->vocab, token)) {
                 std::string result = model->vocab.id_to_token[token].text;
                 llama_unescape_whitespace(result);
@@ -10313,6 +10315,13 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
                 }
                 memcpy(buf, result.c_str(), result.length());
                 return result.length();
+            } else if (llama_is_user_defined_token(model->vocab, token)) {
+                std::string result = model->vocab.id_to_token[token].text;
+                if (length < (int) result.length()) {
+                    return -result.length();
+                }
+                memcpy(buf, result.c_str(), result.length());
+                return result.length();
             } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
                 if (length < 3) {
                     return -3;
@@ -10327,14 +10336,12 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
                 }
                 buf[0] = llama_token_to_byte(model->vocab, token);
                 return 1;
-            } else {
-                // TODO: for now we accept all unsupported token types,
-                // suppressing them like CONTROL tokens.
-                // GGML_ASSERT(false);
             }
             break;
         }
         case LLAMA_VOCAB_TYPE_BPE: {
+            // NOTE: we accept all unsupported token types,
+            // suppressing them like CONTROL tokens.
             if (llama_is_normal_token(model->vocab, token)) {
                 std::string result = model->vocab.id_to_token[token].text;
                 result = llama_decode_text(result);
@@ -10343,12 +10350,15 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
                 }
                 memcpy(buf, result.c_str(), result.length());
                 return result.length();
+            } else if (llama_is_user_defined_token(model->vocab, token)) {
+                std::string result = model->vocab.id_to_token[token].text;
+                if (length < (int) result.length()) {
+                    return -result.length();
+                }
+                memcpy(buf, result.c_str(), result.length());
+                return result.length();
             } else if (llama_is_control_token(model->vocab, token)) {
                 ;
-            } else {
-                // TODO: for now we accept all unsupported token types,
-                // suppressing them like CONTROL tokens.
-                // GGML_ASSERT(false);
             }
             break;
         }