]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add tokenizer test + revert to C++11 (#355)
authorGeorgi Gerganov <redacted>
Tue, 21 Mar 2023 15:29:41 +0000 (17:29 +0200)
committerGitHub <redacted>
Tue, 21 Mar 2023 15:29:41 +0000 (17:29 +0200)
* Add test-tokenizer-0 to do a few tokenizations - feel free to expand
* Added option to convert-pth-to-ggml.py script to dump just the vocabulary
* Added ./models/ggml-vocab.bin containing just LLaMA vocab data (used for tests)
* Added utility to load vocabulary file from previous point (temporary implementation)
* Avoid using std::string_view and drop back to C++11 (hope I didn't break something)
* Rename gpt_vocab -> llama_vocab
* All CMake binaries go into ./bin/ now

.github/workflows/build.yml
CMakeLists.txt
Makefile
convert-pth-to-ggml.py
main.cpp
models/ggml-vocab.bin [new file with mode: 0644]
quantize.cpp
tests/CMakeLists.txt [new file with mode: 0644]
tests/test-tokenizer-0.cpp [new file with mode: 0644]
utils.cpp
utils.h

index 9c1de58234e06b7635ef7445282a4d8c9ea17a85..5b1b5ddfbce64e5ab7789f8f717bc780ba8a5d97 100644 (file)
@@ -54,6 +54,7 @@ jobs:
           cd build
           cmake ..
           cmake --build . --config Release
+          ctest --output-on-failure
 
   macOS-latest-make:
     runs-on: macos-latest
@@ -90,6 +91,7 @@ jobs:
           cd build
           cmake ..
           cmake --build . --config Release
+          ctest --output-on-failure
 
   windows-latest-cmake:
     runs-on: windows-latest
@@ -106,6 +108,7 @@ jobs:
           cd build
           cmake ..
           cmake --build . --config Release
+          ctest --output-on-failure
 
       - name: Get commit hash
         id: commit
index 7f46513c832db92b313061f6d81cde4647181477..bf0e77b4a9b9de382880be8dedcd555d74121680 100644 (file)
@@ -1,11 +1,37 @@
-cmake_minimum_required(VERSION 3.12)
+cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason
 project("llama.cpp" C CXX)
 
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
 if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
     set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
     set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
 endif()
 
+set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+
+if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
+    set(LLAMA_STANDALONE ON)
+
+    # configure project version
+    # TODO
+else()
+    set(LLAMA_STANDALONE OFF)
+endif()
+
+if (EMSCRIPTEN)
+    set(BUILD_SHARED_LIBS_DEFAULT OFF)
+
+    option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON)
+else()
+    if (MINGW)
+        set(BUILD_SHARED_LIBS_DEFAULT OFF)
+    else()
+        set(BUILD_SHARED_LIBS_DEFAULT ON)
+    endif()
+endif()
+
+
 #
 # Option list
 #
@@ -34,6 +60,9 @@ option(LLAMA_FMA                    "llama: enable FMA"
 option(LLAMA_ACCELERATE             "llama: enable Accelerate framework"                    ON)
 option(LLAMA_OPENBLAS               "llama: use OpenBLAS"                                   OFF)
 
+option(LLAMA_BUILD_TESTS            "llama: build tests"    ${LLAMA_STANDALONE})
+option(LLAMA_BUILD_EXAMPLES         "llama: build examples" ${LLAMA_STANDALONE})
+
 #
 # Compile flags
 #
@@ -187,17 +216,19 @@ add_executable(llama main.cpp)
 
 add_executable(quantize quantize.cpp)
 
-add_library(ggml OBJECT
-            ggml.c
-            ggml.h)
-
 add_library(utils OBJECT
             utils.cpp
             utils.h)
 
+target_include_directories(utils PUBLIC .)
+target_compile_features(utils PUBLIC cxx_std_11) # don't bump
+
+add_library(ggml OBJECT
+            ggml.c
+            ggml.h)
+
 target_include_directories(ggml PUBLIC .)
-target_compile_features(ggml PUBLIC c_std_11)
-target_compile_features(utils PUBLIC cxx_std_17)
+target_compile_features(ggml PUBLIC c_std_11) # don't bump
 
 #
 # Linking
@@ -206,3 +237,16 @@ target_compile_features(utils PUBLIC cxx_std_17)
 target_link_libraries(ggml PRIVATE Threads::Threads ${LLAMA_EXTRA_LIBS})
 target_link_libraries(llama PRIVATE ggml utils)
 target_link_libraries(quantize PRIVATE ggml utils)
+
+#
+# programs, examples and tests
+#
+
+if (LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
+    enable_testing()
+    add_subdirectory(tests)
+endif ()
+
+#if (LLAMA_BUILD_EXAMPLES)
+#    add_subdirectory(examples)
+#endif()
index ec2eb75696d9546354aba902fc4a3288ef4b29f2..dffcdbde71d070007a5ab4d5cd19c070cd35f083 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -30,8 +30,9 @@ endif
 # Compile flags
 #
 
+# keep standard at C11 and C++11
 CFLAGS   = -I.              -O3 -DNDEBUG -std=c11   -fPIC
-CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++17 -fPIC
+CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
 LDFLAGS  =
 
 # OS specific
index 108eb1fccfe634b26ff2e9cba98c535c7e28205b..c506676fc7de8503b7cc2819f61c4c40925c27ba 100644 (file)
 #   - Name (char[name_length])
 #   - Data (float[n_dims])
 #
-# By default, the bigger matrices are converted to 16-bit floats.
-# This can be disabled by adding the "use-f32" CLI argument.
-#
 # At the start of the ggml file we write the model parameters
 # and vocabulary.
 #
+
 import argparse
 import os
 import sys
@@ -23,6 +21,7 @@ import json
 import struct
 import numpy as np
 import torch
+
 from sentencepiece import SentencePieceProcessor
 
 def parse_args():
@@ -30,6 +29,7 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
     parser.add_argument('dir_model', help='directory containing the model checkpoint')
     parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)')
+    parser.add_argument('vocab_only', type=bool, default=False, help='only write vocab to file')
     return parser.parse_args()
 
 def get_n_parts(dim):
@@ -134,6 +134,27 @@ def main():
     ftype_str = ["f32", "f16"]
 
     hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
+
+    # if only writing vocab to file
+    if args.vocab_only:
+
+        fname_model = f"{dir_model}/consolidated.00.pth"
+        fname_out = f"{dir_model}/ggml-vocab.bin"
+
+        print(f"Extracting only the vocab from '{fname_model}'\n")
+
+        model = torch.load(fname_model, map_location="cpu")
+
+        with open(fname_out, "wb") as fout:
+            fout.write(struct.pack("i", hparams["vocab_size"]))
+            write_tokens(fout, tokenizer)
+
+        del model
+
+        print(f"Done. Output file: {fname_out}\n")
+
+        return
+
     n_parts = get_n_parts(hparams["dim"])
 
     for p in range(n_parts):
@@ -151,6 +172,7 @@ def main():
             process_and_write_variables(fout, model, ftype)
 
         del model
+
         print(f"Done. Output file: {fname_out}, (part {p})\n")
 
 if __name__ == "__main__":
index 3321818d3e27f5638c6bcec1c84ed5ee4fe0f471..e97611e2882c60e0266f993f0b619ef2335aa4b2 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -90,7 +90,7 @@ struct llama_model {
 };
 
 // load the model's weights from a file
-bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
+bool llama_model_load(const std::string & fname, llama_model & model, llama_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     std::vector<char> f_buf(1024*1024);
@@ -544,9 +544,9 @@ bool llama_eval(
         const llama_model & model,
         const int n_threads,
         const int n_past,
-        const std::vector<gpt_vocab::id> & embd_inp,
-              std::vector<float>         & embd_w,
-              size_t                     & mem_per_token) {
+        const std::vector<llama_vocab::id> & embd_inp,
+              std::vector<float>           & embd_w,
+              size_t                       & mem_per_token) {
     const int N = embd_inp.size();
 
     const auto & hparams = model.hparams;
@@ -832,7 +832,7 @@ int main(int argc, char ** argv) {
 
     int64_t t_load_us = 0;
 
-    gpt_vocab vocab;
+    llama_vocab vocab;
     llama_model model;
 
     // load the model
@@ -864,13 +864,13 @@ int main(int argc, char ** argv) {
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');
     // tokenize the prompt
-    std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);
+    std::vector<llama_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);
 
     params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
 
     // prefix & suffix for instruct mode
-    const std::vector<gpt_vocab::id> inp_pfx = ::llama_tokenize(vocab, "\n\n### Instruction:\n\n", true);
-    const std::vector<gpt_vocab::id> inp_sfx = ::llama_tokenize(vocab, "\n\n### Response:\n\n", false);
+    const std::vector<llama_vocab::id> inp_pfx = ::llama_tokenize(vocab, "\n\n### Instruction:\n\n", true);
+    const std::vector<llama_vocab::id> inp_sfx = ::llama_tokenize(vocab, "\n\n### Response:\n\n", false);
 
     // in instruct mode, we inject a prefix and a suffix to each input by the user
     if (params.instruct) {
@@ -879,8 +879,8 @@ int main(int argc, char ** argv) {
     }
 
     // tokenize the reverse prompt
-    std::vector<std::vector<gpt_vocab::id>> antipromptv_inp;
-    
+    std::vector<std::vector<llama_vocab::id>> antipromptv_inp;
+
     for (auto antiprompt : params.antiprompt) {
         antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false));
     }
@@ -925,14 +925,14 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
     fprintf(stderr, "\n\n");
 
-    std::vector<gpt_vocab::id> embd;
+    std::vector<llama_vocab::id> embd;
 
     // determine the required inference memory per token:
     size_t mem_per_token = 0;
     llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
 
     int last_n_size = params.repeat_last_n;
-    std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
+    std::vector<llama_vocab::id> last_n_tokens(last_n_size);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
     if (params.interactive) {
@@ -980,7 +980,7 @@ int main(int argc, char ** argv) {
 
             const int n_vocab = model.hparams.n_vocab;
 
-            gpt_vocab::id id = 0;
+            llama_vocab::id id = 0;
 
             {
                 const int64_t t_start_sample_us = ggml_time_us();
@@ -1066,7 +1066,7 @@ int main(int argc, char ** argv) {
                 } while (another_line);
                 if (params.use_color) printf(ANSI_COLOR_RESET);
 
-                std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
+                std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
                 embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
 
                 if (params.instruct) {
diff --git a/models/ggml-vocab.bin b/models/ggml-vocab.bin
new file mode 100644 (file)
index 0000000..aba94bd
Binary files /dev/null and b/models/ggml-vocab.bin differ
index 07db33a3caefe0b1baa85a599d688f59930cd1ec..b90f34f480cb35a8674e0e6356e771f07b20d17c 100644 (file)
@@ -44,7 +44,7 @@ bool llama_model_quantize(const std::string & fname_inp, const std::string & fna
         return false;
     }
 
-    gpt_vocab vocab;
+    llama_vocab vocab;
 
     printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
 
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
new file mode 100644 (file)
index 0000000..a2c1e3f
--- /dev/null
@@ -0,0 +1,4 @@
+set(TEST_TARGET test-tokenizer-0)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE utils)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp
new file mode 100644 (file)
index 0000000..6bc49f2
--- /dev/null
@@ -0,0 +1,69 @@
+#include "utils.h"
+
+#include <cstdio>
+#include <string>
+#include <map>
+
+static const std::map<std::string, std::vector<llama_vocab::id>> k_tests = {
+    { "Hello World",        { 1,  10994,   2787, }, },
+    { " Hello World",       { 1,  15043,   2787, }, },
+    { " Hello World!",      { 1,  15043,   2787,  29991, }, },
+    { " this is 🦙.cpp",    { 1,    445,    338,  29871,    243,    162,    169,    156,  29889,   8223, }, },
+    { "w048 7tuijk dsdfhu", { 1,  29893,  29900,  29946,  29947,  29871,  29955,   9161,  13535,  18031,   2176,   6905, }, },
+    { "нещо на Български",  { 1,    821,   4851,    665,   1386,  29713,   1305, }, },
+};
+
+int main(int argc, char **argv) {
+    if (argc < 2) {
+        fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
+        return 1;
+    }
+
+    const std::string fname = argv[1];
+
+    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
+
+    llama_vocab vocab;
+
+    if (!llama_vocab_load(fname, vocab)) {
+        fprintf(stderr, "%s : failed to load vocab from: '%s'\n", __func__, fname.c_str());
+        return 1;
+    }
+
+    const int n_vocab = vocab.id_to_token.size();
+
+    if (n_vocab != 32000) {
+        fprintf(stderr, "%s : expected 32000 tokens, got %d\n", __func__, n_vocab);
+        return 2;
+    }
+
+    for (const auto & test_kv : k_tests) {
+        const auto res = llama_tokenize(vocab, test_kv.first, true);
+
+        bool correct = res.size() == test_kv.second.size();
+
+        for (int i = 0; i < (int) res.size() && correct; ++i) {
+            if (res[i] != test_kv.second[i]) {
+                correct = false;
+            }
+        }
+
+        if (!correct) {
+            fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
+            fprintf(stderr, "%s : expected tokens: ", __func__);
+            for (const auto & t : test_kv.second) {
+                fprintf(stderr, "%6d, ", t);
+            }
+            fprintf(stderr, "\n");
+            fprintf(stderr, "%s : got tokens:      ", __func__);
+            for (const auto & t : res) {
+                fprintf(stderr, "%6d, ", t);
+            }
+            fprintf(stderr, "\n");
+
+            return 3;
+        }
+    }
+
+    return 0;
+}
index 188f114e9f1cd4f7489ea9a343785495e00b98dc..4843b4f557f12c1b277819292ac5279b0489073f 100644 (file)
--- a/utils.cpp
+++ b/utils.cpp
@@ -240,61 +240,6 @@ std::map<std::string, int32_t> json_parse(const std::string & fname) {
     return result;
 }
 
-std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
-    std::vector<std::string> words;
-
-    // first split the text into words
-    {
-        std::string str = text;
-        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
-
-        std::regex re(pat);
-        std::smatch m;
-
-        while (std::regex_search(str, m, re)) {
-            for (auto x : m) {
-                words.push_back(x);
-            }
-            str = m.suffix();
-        }
-    }
-
-    // find the longest tokens that form the words:
-    std::vector<gpt_vocab::id> tokens;
-    for (const auto & word : words) {
-        if (word.size() == 0) continue;
-
-        int i = 0;
-        int n = word.size();
-        while (i < n) {
-            int j = n;
-            while (j > i) {
-                auto it = vocab.token_to_id.find(word.substr(i, j-i));
-                if (it != vocab.token_to_id.end()) {
-                    tokens.push_back(it->second);
-                    i = j;
-                    break;
-                }
-                --j;
-            }
-            if (i == n) {
-                break;
-            }
-            if (j == i) {
-                auto sub = word.substr(i, 1);
-                if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
-                    tokens.push_back(vocab.token_to_id.at(sub));
-                } else {
-                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
-                }
-                ++i;
-            }
-        }
-    }
-
-    return tokens;
-}
-
 static size_t utf8_len(char src) {
     const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
     uint8_t highbits = static_cast<uint8_t>(src) >> 4;
@@ -305,7 +250,8 @@ struct llama_sp_symbol {
     using index = int;
     index prev;
     index next;
-    std::string_view text;
+    const char * text;
+    size_t n;
 };
 
 struct llama_sp_bigram {
@@ -322,19 +268,23 @@ struct llama_sp_bigram {
     size_t size;
 };
 
+// original implementation:
+// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
 struct llama_tokenizer {
-    llama_tokenizer(const gpt_vocab & vocab): vocab_(vocab) {}
+    llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}
 
-    void tokenize(std::string_view text, std::vector<gpt_vocab::id> & output) {
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         // split string into utf8 chars
         int index = 0;
-        while (!text.empty()) {
+        size_t offs = 0;
+        while (offs < text.size()) {
             llama_sp_symbol sym;
-            size_t char_len = std::min(text.size(), utf8_len(text.data()[0]));
-            sym.text = std::string_view(text.data(), char_len);
+            size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
+            sym.text = text.c_str() + offs;
+            sym.n = char_len;
+            offs += char_len;
             sym.prev = index - 1;
-            text.remove_prefix(char_len);
-            sym.next = text.empty() ? -1 : index + 1;
+            sym.next = offs == text.size() ? -1 : index + 1;
             index++;
             symbols_.emplace_back(std::move(sym));
         }
@@ -353,14 +303,16 @@ struct llama_tokenizer {
             auto & right_sym = symbols_[bigram.right];
 
             // if one of the symbols already got merged, skip it.
-            if (left_sym.text.empty() || right_sym.text.empty() ||
-                left_sym.text.size() + right_sym.text.size() != bigram.size) {
+            if (left_sym.n == 0 || right_sym.n == 0 ||
+                left_sym.n + right_sym.n != bigram.size) {
                 continue;
             }
 
             // merge the right sym into the left one
-            left_sym.text = std::string_view(left_sym.text.data(), left_sym.text.size() + right_sym.text.size());
-            right_sym.text = std::string_view("");
+            left_sym.n += right_sym.n;
+            right_sym.n = 0;
+
+            //printf("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);
 
             // remove the right sym from the chain
             left_sym.next = right_sym.next;
@@ -374,13 +326,13 @@ struct llama_tokenizer {
         }
 
         for (int i = 0; i != -1; i = symbols_[i].next) {
-            auto& symbol = symbols_[i];
-            auto token = vocab_.token_to_id.find(std::string(symbol.text));
+            auto & symbol = symbols_[i];
+            auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n));
 
             if (token == vocab_.token_to_id.end()) {
                 // output any symbols that did not form tokens as bytes.
-                for (int j = 0; j < symbol.text.size(); ++j) {
-                    gpt_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
+                for (int j = 0; j < (int) symbol.n; ++j) {
+                    llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
                     output.push_back(token_id);
                 }
             } else {
@@ -395,8 +347,8 @@ private:
             return;
         }
 
-        std::string_view text(symbols_[left].text.data(), symbols_[left].text.size() + symbols_[right].text.size());
-        auto token = vocab_.token_to_id.find(std::string(text));
+        const std::string text = std::string(symbols_[left].text, symbols_[left].n + symbols_[right].n);
+        auto token = vocab_.token_to_id.find(text);
 
         if (token == vocab_.token_to_id.end()) {
             return;
@@ -416,14 +368,52 @@ private:
         work_queue_.push(bigram);
     }
 
-    const gpt_vocab & vocab_;
+    const llama_vocab & vocab_;
     std::vector<llama_sp_symbol> symbols_;
     llama_sp_bigram::queue work_queue_;
 };
 
-std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, std::string_view text, bool bos) {
+// TODO: temporary code duplication with llama.cpp
+//       will resolve after #77 is merged
+bool llama_vocab_load(const std::string & fname, llama_vocab & vocab) {
+    std::ifstream fin(fname, std::ios::binary);
+    if (!fin.is_open()) {
+        return false;
+    }
+
+    int n_vocab = 0;
+    fin.read((char *) &n_vocab, sizeof(n_vocab));
+
+    std::string word;
+    std::vector<char> tmp(64);
+
+    for (int i = 0; i < n_vocab; i++) {
+        uint32_t len;
+        fin.read((char *) &len, sizeof(len));
+
+        word.resize(len);
+        if (len > 0) {
+            tmp.resize(len);
+            fin.read(tmp.data(), len);
+            word.assign(tmp.data(), len);
+        } else {
+            word.clear();
+        }
+
+        float score;
+        fin.read((char *) &score, sizeof(score));
+
+        vocab.token_to_id[word] = i;
+        vocab.id_to_token[i] = word;
+        vocab.score[i] = score;
+    }
+
+    return true;
+}
+
+std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
     llama_tokenizer tokenizer(vocab);
-    std::vector<gpt_vocab::id> output;
+    std::vector<llama_vocab::id> output;
 
     if (text.size() == 0) {
         return output;
@@ -437,42 +427,22 @@ std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, std::string_v
     return output;
 }
 
-bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
-    printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
-
-    vocab.token_to_id = ::json_parse(fname);
-
-    for (const auto & kv : vocab.token_to_id) {
-        vocab.id_to_token[kv.second] = kv.first;
-    }
-
-    printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
-
-    // print the vocabulary
-    //for (auto kv : vocab.token_to_id) {
-    //    printf("'%s' -> %d\n", kv.first.data(), kv.second);
-    //}
-
-    return true;
-}
-
-
-void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k) {
+void sample_top_k(std::vector<std::pair<double, llama_vocab::id>> & logits_id, int top_k) {
     // find the top K tokens
     std::partial_sort(
             logits_id.begin(),
             logits_id.begin() + top_k, logits_id.end(),
-            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
+            [](const std::pair<double, llama_vocab::id> & a, const std::pair<double, llama_vocab::id> & b) {
         return a.first > b.first;
     });
 
     logits_id.resize(top_k);
 }
 
-gpt_vocab::id llama_sample_top_p_top_k(
-        const gpt_vocab & vocab,
+llama_vocab::id llama_sample_top_p_top_k(
+        const llama_vocab & vocab,
         const float * logits,
-        std::vector<gpt_vocab::id> & last_n_tokens,
+        std::vector<llama_vocab::id> & last_n_tokens,
         double repeat_penalty,
         int top_k,
         double top_p,
@@ -480,7 +450,7 @@ gpt_vocab::id llama_sample_top_p_top_k(
         std::mt19937 & rng) {
     int n_logits = vocab.id_to_token.size();
 
-    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
+    std::vector<std::pair<double, llama_vocab::id>> logits_id;
     logits_id.reserve(n_logits);
 
     {
diff --git a/utils.h b/utils.h
index 65fe02ba15f73237fd95280eede259a24ef2a333..971cc0e982e771361468810a589e574132bc819a 100644 (file)
--- a/utils.h
+++ b/utils.h
@@ -60,7 +60,7 @@ std::string gpt_random_prompt(std::mt19937 & rng);
 // Vocab utils
 //
 
-struct gpt_vocab {
+struct llama_vocab {
     using id    = int32_t;
     using token = std::string;
 
@@ -74,34 +74,22 @@ void replace(std::string & str, const std::string & needle, const std::string &
 // poor-man's JSON parsing
 std::map<std::string, int32_t> json_parse(const std::string & fname);
 
-// split text into tokens
-//
-// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
-//
-// Regex (Python):
-// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
-//
-// Regex (C++):
-// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
-//
-std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
+// TODO: temporary until #77 is merged, need this now for some tokenizer tests
+bool llama_vocab_load(const std::string & fname, llama_vocab & vocab);
 
 // TODO: this is probably wrong, but I cannot figure out how this tokenizer works ..
 // ref: https://github.com/google/sentencepiece
-std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, std::string_view text, bool bos);
-
-// load the tokens from encoder.json
-bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
+std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos);
 
 // sample next token given probabilities for each embedding
 //
 //   - consider only the top K tokens
 //   - from them, consider only the top tokens with cumulative probability > P
 //
-gpt_vocab::id llama_sample_top_p_top_k(
-        const gpt_vocab & vocab,
+llama_vocab::id llama_sample_top_p_top_k(
+        const llama_vocab & vocab,
         const float * logits,
-        std::vector<gpt_vocab::id> & last_n_tokens,
+        std::vector<llama_vocab::id> & last_n_tokens,
         double repeat_penalty,
         int top_k,
         double top_p,
@@ -109,7 +97,7 @@ gpt_vocab::id llama_sample_top_p_top_k(
         std::mt19937 & rng);
 
 // filer to top K tokens from list of logits
-void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k);
+void sample_top_k(std::vector<std::pair<double, llama_vocab::id>> & logits_id, int top_k);
 
 //
 // Quantization