]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : fix handling of characters that span multiple tokens when streaming (#4446)
authorshibe2 <redacted>
Wed, 13 Dec 2023 19:57:15 +0000 (23:57 +0400)
committerGitHub <redacted>
Wed, 13 Dec 2023 19:57:15 +0000 (21:57 +0200)
examples/server/server.cpp

index d0cd8e1cdb2112889fd95b0c4f0bb987ea09e214..39d1e83d1857512a6ef3efb9c874661180b04c15 100644 (file)
@@ -376,7 +376,6 @@ struct llama_client_slot
 
     int32_t num_prompt_tokens           = 0;
     int32_t num_prompt_tokens_processed = 0;
-    int32_t multibyte_pending           = 0;
 
     json prompt;
     std::string generated_text;
@@ -425,7 +424,6 @@ struct llama_client_slot
         stopped_word           = false;
         stopped_limit          = false;
         stopping_word          = "";
-        multibyte_pending      = 0;
         n_past                 = 0;
         sent_count             = 0;
         sent_token_probs_index = 0;
@@ -992,35 +990,36 @@ struct llama_server_context
         slot.generated_text += token_str;
         slot.has_next_token = true;
 
-        if (slot.multibyte_pending > 0)
+        // check if there is incomplete UTF-8 character at the end
+        bool incomplete = false;
+        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i)
         {
-            slot.multibyte_pending -= token_str.size();
-        }
-        else if (token_str.size() == 1)
-        {
-            const char c = token_str[0];
-            // 2-byte characters: 110xxxxx 10xxxxxx
+            unsigned char c = slot.generated_text[slot.generated_text.size() - i];
+            if ((c & 0xC0) == 0x80)
+            {
+                // continuation byte: 10xxxxxx
+                continue;
+            }
             if ((c & 0xE0) == 0xC0)
             {
-                slot.multibyte_pending = 1;
-                // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
+                // 2-byte character: 110xxxxx ...
+                incomplete = i < 2;
             }
             else if ((c & 0xF0) == 0xE0)
             {
-                slot.multibyte_pending = 2;
-                // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+                // 3-byte character: 1110xxxx ...
+                incomplete = i < 3;
             }
             else if ((c & 0xF8) == 0xF0)
             {
-                slot.multibyte_pending = 3;
-            }
-            else
-            {
-                slot.multibyte_pending = 0;
+                // 4-byte character: 11110xxx ...
+                incomplete = i < 4;
             }
+            // else 1-byte character or invalid byte
+            break;
         }
 
-        if (slot.multibyte_pending == 0)
+        if (!incomplete)
         {
             size_t pos = std::min(slot.sent_count, slot.generated_text.size());
             const std::string str_test = slot.generated_text.substr(pos);
@@ -1055,7 +1054,7 @@ struct llama_server_context
             }
         }
 
-        if (slot.multibyte_pending > 0 && !slot.has_next_token)
+        if (incomplete)
         {
             slot.has_next_token = true;
         }