]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : cache llama_token_to_piece (#7587)
authorGeorgi Gerganov <redacted>
Thu, 30 May 2024 16:01:41 +0000 (19:01 +0300)
committerGitHub <redacted>
Thu, 30 May 2024 16:01:41 +0000 (02:01 +1000)
* llama : cache llama_token_to_piece

ggml-ci

* llama : use vectors and avoid has_cache

ggml-ci

* llama : throw on unknown tokenizer types

ggml-ci

* llama : print a log of the total cache size

llama.cpp
llama.h

index e7412de4b6cac3757ea7be05461db09064b08cfa..40d2ec2c967f2a1c8e63fd647e331ccefd4b212a 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1702,12 +1702,13 @@ struct llama_mlock {
 };
 using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
 
-static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+// NOTE: avoid ever using this except for building the token_to_piece caches
+static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
+    const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
+        int check = llama_token_to_piece(model, token, result.data(), result.size(), special);
         GGML_ASSERT(check == -n_tokens);
     }
     else {
@@ -2162,7 +2163,9 @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data>       id_to_token;
 
-    std::vector<id> special_tokens_cache;
+    std::vector<id>    cache_special_tokens;
+    std::vector<token> cache_token_to_piece;         // llama_token_to_piece(special = false);
+    std::vector<token> cache_token_to_piece_special; // llama_token_to_piece(special = true);
 
     std::map<std::pair<std::string, std::string>, int> bpe_ranks;
 
@@ -4592,20 +4595,14 @@ static void llm_load_vocab(
             vocab.special_cls_id  = 101;
             vocab.special_mask_id = 103;
             vocab.add_space_prefix = false;
-        } else {
-            if (tokenizer_model == "gpt2") {
-                vocab.type = LLAMA_VOCAB_TYPE_BPE;
+        } else if (tokenizer_model == "gpt2") {
+            vocab.type = LLAMA_VOCAB_TYPE_BPE;
 
-                const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
-                if (add_space_prefix_keyidx != -1) {
-                    vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
-                }
-            } else {
-                LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_model.c_str());
-                LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
-                vocab.type = LLAMA_VOCAB_TYPE_SPM;
-                return;
+            const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
+            if (add_space_prefix_keyidx != -1) {
+                vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
             }
+
             // read bpe merges and populate bpe ranks
             const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
             if (merges_keyidx == -1) {
@@ -4639,6 +4636,8 @@ static void llm_load_vocab(
             vocab.special_pad_id  = -1;
             vocab.special_cls_id  = -1;
             vocab.special_mask_id = -1;
+        } else {
+            throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
         }
 
         // for now, only BPE models have pre-tokenizers
@@ -4833,17 +4832,38 @@ static void llm_load_vocab(
     {
         for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
             if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
-                vocab.special_tokens_cache.push_back(id);
+                vocab.cache_special_tokens.push_back(id);
             }
         }
 
-        std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(),
+        std::sort( vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
             [&] (const llama_vocab::id a, const llama_vocab::id b) {
                 return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
             }
         );
 
-        LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size());
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
+    }
+
+    // build token to piece caches
+    {
+        size_t size_cache = 0;
+
+        std::vector<llama_vocab::token> cache_token_to_piece        (n_vocab);
+        std::vector<llama_vocab::token> cache_token_to_piece_special(n_vocab);
+
+        for (uint32_t id = 0; id < n_vocab; ++id) {
+            cache_token_to_piece[id]         = llama_token_to_piece(&model, id, false);
+            cache_token_to_piece_special[id] = llama_token_to_piece(&model, id, true);
+
+            size_cache += cache_token_to_piece[id].size();
+            size_cache += cache_token_to_piece_special[id].size();
+        }
+
+        std::swap(vocab.cache_token_to_piece,         cache_token_to_piece);
+        std::swap(vocab.cache_token_to_piece_special, cache_token_to_piece_special);
+
+        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
     }
 }
 
@@ -13233,7 +13253,7 @@ struct fragment_buffer_variant {
 
 static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
     // for each special token
-    for (const llama_vocab::id special_id : vocab.special_tokens_cache) {
+    for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
         const auto & special_token = vocab.id_to_token[special_id].text;
 
         // for each text fragment
@@ -14392,7 +14412,7 @@ void llama_sample_repetition_penalties(
 
 void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
     GGML_ASSERT(ctx);
-    const int64_t t_start_sample_us = ggml_time_us();
+    int64_t t_start_sample_us = ggml_time_us();
 
     bool allow_eog = false;
     for (const auto & stack : grammar->stacks) {
@@ -14404,12 +14424,13 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
 
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
     candidates_decoded.reserve(candidates->size);
-    std::vector<llama_grammar_candidate>                              candidates_grammar;
+
+    std::vector<llama_grammar_candidate> candidates_grammar;
     candidates_grammar.reserve(candidates->size);
 
     for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id    = candidates->data[i].id;
-        const std::string piece = llama_token_to_piece(ctx, id, false);
+        const llama_token id      = candidates->data[i].id;
+        const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id);
 
         if (llama_token_is_eog(&ctx->model, id)) {
             if (!allow_eog) {
@@ -14609,7 +14630,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         GGML_ASSERT(false);
     }
 
-    const std::string piece = llama_token_to_piece(ctx, token, false);
+    const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token);
 
     // Note terminating 0 in decoded string
     const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);
@@ -18292,69 +18313,83 @@ static std::string llama_decode_text(const std::string & text) {
 
 // does not write null-terminator to buf
 int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
+    // if we have a cache - use it
+    {
+        const auto & cache = special ? model->vocab.cache_token_to_piece_special : model->vocab.cache_token_to_piece;
+
+        if (!cache.empty()) {
+            const auto & res = cache.at(token);
+            if (length < (int) res.size()) {
+                return -(int) res.size();
+            }
+            memcpy(buf, res.c_str(), res.size());
+            return res.size();
+        }
+    }
+
     if (0 <= token && token < llama_n_vocab(model)) {
         switch (llama_vocab_get_type(model->vocab)) {
-        case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_SPM: {
-            // NOTE: we accept all unsupported token types,
-            // suppressing them like CONTROL tokens.
-            if (llama_is_normal_token(model->vocab, token)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                llama_unescape_whitespace(result);
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (
-                    (llama_is_user_defined_token(model->vocab, token)) ||
-                    (llama_is_control_token     (model->vocab, token) && special)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
-                if (length < 3) {
-                    return -3;
-                }
-                memcpy(buf, "\xe2\x96\x85", 3);
-                return 3;
-            } else if (llama_is_byte_token(model->vocab, token)) {
-                if (length < 1) {
-                    return -1;
+            case LLAMA_VOCAB_TYPE_WPM:
+            case LLAMA_VOCAB_TYPE_SPM: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (llama_is_normal_token(model->vocab, token)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    llama_unescape_whitespace(result);
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (
+                        (llama_is_user_defined_token(model->vocab, token)) ||
+                        (llama_is_control_token     (model->vocab, token) && special)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
+                    if (length < 3) {
+                        return -3;
+                    }
+                    memcpy(buf, "\xe2\x96\x85", 3);
+                    return 3;
+                } else if (llama_is_byte_token(model->vocab, token)) {
+                    if (length < 1) {
+                        return -1;
+                    }
+                    buf[0] = llama_token_to_byte(model->vocab, token);
+                    return 1;
                 }
-                buf[0] = llama_token_to_byte(model->vocab, token);
-                return 1;
+                break;
             }
-            break;
-        }
-        case LLAMA_VOCAB_TYPE_BPE: {
-            // NOTE: we accept all unsupported token types,
-            // suppressing them like CONTROL tokens.
-            if (llama_is_normal_token(model->vocab, token)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                result = llama_decode_text(result);
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (
-                    (llama_is_user_defined_token(model->vocab, token)) ||
-                    (llama_is_control_token     (model->vocab, token) && special)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
+            case LLAMA_VOCAB_TYPE_BPE: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (llama_is_normal_token(model->vocab, token)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    result = llama_decode_text(result);
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (
+                        (llama_is_user_defined_token(model->vocab, token)) ||
+                        (llama_is_control_token     (model->vocab, token) && special)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
                 }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
+                break;
             }
-            break;
-        }
-        default:
-            GGML_ASSERT(false);
+            default:
+                GGML_ASSERT(false);
         }
     }
     return 0;
diff --git a/llama.h b/llama.h
index 3e4474bb94e9affdb6c70197eae232ec31546d7a..95105c28e5e42b038352695d51b51c80d4935350 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -424,8 +424,8 @@ extern "C" {
 
     LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
 
-    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model   * model);
-    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model   * model);
+    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model * model);
+    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model * model);
 
     LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);