]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
infill. : fix tokenization (#3508)
authorvvhg1 <redacted>
Tue, 10 Oct 2023 07:31:21 +0000 (09:31 +0200)
committerGitHub <redacted>
Tue, 10 Oct 2023 07:31:21 +0000 (10:31 +0300)
* infill tokens correction

* serverinfill tokens correction

* removing any leading whitespace from infill suffix and removing leeading space token from suffix when params.escape

* removing any leading whitespace from infill suffix and removing leeading space token from suffix when params.escape

* only rm when params.escape, rm space if possible which is added back or rm added space token

* only rm when params.escape, rm space if possible which is added back or rm added space token

* Revert "only rm when params.escape, rm space if possible which is added back or rm added space token"

This reverts commit 63ba0b621f21077c0e3bc6ba6a327534123cb738.

* fix interactive prompt escaping and fix server infill leading space handling

* rm unnecessary bool check

examples/infill/infill.cpp
examples/server/server.cpp

index 9ec75ce425b2a6c771e115c27f780c96f17f6ce2..d994de5e850c3e7f3477545e81b84db2a5ba2fbe 100644 (file)
@@ -233,10 +233,22 @@ int main(int argc, char ** argv) {
     const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
     LOG("add_bos: %d\n", add_bos);
 
+    bool suff_rm_leading_spc = params.escape;
+    if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+        params.input_suffix.erase(0, 1);
+        suff_rm_leading_spc = false;
+    }
     std::vector<llama_token> embd_inp;
-    std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
-    std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
+    std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
+    std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
+    const int space_token = 29871;
+    if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
+        inp_sfx.erase(inp_sfx.begin());
+    }
     inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
+    if (add_bos) {
+        inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
+    }
     inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
     embd_inp = inp_pfx;
     embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
@@ -627,10 +639,27 @@ int main(int argc, char ** argv) {
                 buffer.clear();
                 // done taking input, reset color
                 console::set_display(console::reset);
+
+                if (params.escape) {
+                    //process escape sequences, for the initial prompt this is done in common.cpp when we load the params, but for the interactive mode we need to do it here
+                    process_escapes(params.input_prefix);
+                    process_escapes(params.input_suffix);
+                }
+                suff_rm_leading_spc = params.escape;
+                if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+                    params.input_suffix.erase(0, 1);
+                    suff_rm_leading_spc = false;
+                }
                 // tokenize new prefix and suffix
-                std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
-                std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
+                std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
+                std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
+                if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
+                    inp_sfx.erase(inp_sfx.begin());
+                }
                 inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
+                if (add_bos) {
+                    inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
+                }
                 inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
                 embd_inp = inp_pfx;
                 embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
index c53a64867336f9433417d215871f194fe2df4bd4..8c5318c650ae8e935420718fe3d656dff1349ad0 100644 (file)
@@ -344,9 +344,20 @@ struct llama_server_context
 
     void loadInfill()
     {
-        auto prefix_tokens = tokenize(params.input_prefix, true);  // always add BOS
-        auto suffix_tokens = tokenize(params.input_suffix, true);  // always add BOS
+        bool suff_rm_leading_spc = true;
+        if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+            params.input_suffix.erase(0, 1);
+            suff_rm_leading_spc = false;
+        }
+
+        auto prefix_tokens = tokenize(params.input_prefix, false);
+        auto suffix_tokens = tokenize(params.input_suffix, false);
+        const int space_token = 29871;
+        if (suff_rm_leading_spc  && suffix_tokens[0] == space_token) {
+            suffix_tokens.erase(suffix_tokens.begin());
+        }
         prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
+        prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS
         prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
         prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
         prefix_tokens.push_back(llama_token_middle(ctx));