]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
android : fix utf8 decoding error (#5935)
authorDean <redacted>
Sun, 10 Mar 2024 20:03:17 +0000 (04:03 +0800)
committerGitHub <redacted>
Sun, 10 Mar 2024 20:03:17 +0000 (22:03 +0200)
* examples: fix utf8 decoding error

some models have a tokenizer that decodes an id into an incomplete utf8 sequence, need to validate and wait for next token
one example would be: https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat-GGUF/resolve/main/qwen1_5-1_8b-chat-q4_0.gguf and and an example of the token is 18137

* android : minor

---------

Co-authored-by: zhangfuwen <redacted>
Co-authored-by: Georgi Gerganov <redacted>
examples/llama.android/app/src/main/cpp/llama-android.cpp
examples/llama.android/app/src/main/java/com/example/llama/Llm.kt

index 2beb1e0d5321df1b9dda9f3c248f7ac8dd950ef9..ce8ab3b7094073542cabc0471f286751f9117df6 100644 (file)
@@ -33,6 +33,45 @@ jclass la_int_var;
 jmethodID la_int_var_value;
 jmethodID la_int_var_inc;
 
+std::string cached_token_chars;
+
+bool is_valid_utf8(const char * string) {
+    if (!string) {
+        return true;
+    }
+
+    const unsigned char * bytes = (const unsigned char *)string;
+    int num;
+
+    while (*bytes != 0x00) {
+        if ((*bytes & 0x80) == 0x00) {
+            // U+0000 to U+007F
+            num = 1;
+        } else if ((*bytes & 0xE0) == 0xC0) {
+            // U+0080 to U+07FF
+            num = 2;
+        } else if ((*bytes & 0xF0) == 0xE0) {
+            // U+0800 to U+FFFF
+            num = 3;
+        } else if ((*bytes & 0xF8) == 0xF0) {
+            // U+10000 to U+10FFFF
+            num = 4;
+        } else {
+            return false;
+        }
+
+        bytes += 1;
+        for (int i = 1; i < num; ++i) {
+            if ((*bytes & 0xC0) != 0x80) {
+                return false;
+            }
+            bytes += 1;
+        }
+    }
+
+    return true;
+}
+
 static void log_callback(ggml_log_level level, const char * fmt, void * data) {
     if (level == GGML_LOG_LEVEL_ERROR)     __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
     else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
@@ -295,6 +334,8 @@ Java_com_example_llama_Llm_completion_1init(
         jint n_len
     ) {
 
+    cached_token_chars.clear();
+
     const auto text = env->GetStringUTFChars(jtext, 0);
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
     const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
@@ -372,8 +413,16 @@ Java_com_example_llama_Llm_completion_1loop(
     }
 
     auto new_token_chars = llama_token_to_piece(context, new_token_id);
-    LOGi("new_token_chars: `%s`", new_token_chars.c_str());
-    auto new_token = env->NewStringUTF(new_token_chars.c_str());
+    cached_token_chars += new_token_chars;
+
+    jstring new_token = nullptr;
+    if (is_valid_utf8(cached_token_chars.c_str())) {
+        new_token = env->NewStringUTF(cached_token_chars.c_str());
+        LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
+        cached_token_chars.clear();
+    } else {
+        new_token = env->NewStringUTF("");
+    }
 
     llama_batch_clear(*batch);
     llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
index 5f32703724a493c504d56ff0b8cf6aad332587fb..d86afee379083610126988b6d3f18b5fa3bb5818 100644 (file)
@@ -71,7 +71,7 @@ class Llm {
         batch: Long,
         nLen: Int,
         ncur: IntVar
-    ): String
+    ): String?
 
     private external fun kv_cache_clear(context: Long)
 
@@ -115,7 +115,7 @@ class Llm {
                 val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
                 while (ncur.value <= nlen) {
                     val str = completion_loop(state.context, state.batch, nlen, ncur)
-                    if (str.isEmpty()) {
+                    if (str == null) {
                         break
                     }
                     emit(str)