]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama.android: add field formatChat to control whether to parse special tokens when...
authorcodezjx <redacted>
Fri, 17 Jan 2025 12:57:56 +0000 (20:57 +0800)
committerGitHub <redacted>
Fri, 17 Jan 2025 12:57:56 +0000 (14:57 +0200)
examples/llama.android/llama/src/main/cpp/llama-android.cpp
examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt

index 99b14961d5ac1d3027a0a208bc480d4af60d573b..2a73983a9832fb97961d32d44b9a894ccf5843d1 100644 (file)
@@ -347,6 +347,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
         jlong context_pointer,
         jlong batch_pointer,
         jstring jtext,
+        jboolean format_chat,
         jint n_len
     ) {
 
@@ -356,7 +357,8 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
     const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
 
-    const auto tokens_list = common_tokenize(context, text, 1);
+    bool parse_special = (format_chat == JNI_TRUE);
+    const auto tokens_list = common_tokenize(context, text, true, parse_special);
 
     auto n_ctx = llama_n_ctx(context);
     auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
@@ -368,7 +370,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     }
 
     for (auto id : tokens_list) {
-        LOGi("%s", common_token_to_piece(context, id).c_str());
+        LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
     }
 
     common_batch_clear(*batch);
index cf520e4594004194f66d3ee27643c2a807a825a1..b964d93e37819fd8968391e93b731c2f3015ae5a 100644 (file)
@@ -65,6 +65,7 @@ class LLamaAndroid {
         context: Long,
         batch: Long,
         text: String,
+        formatChat: Boolean,
         nLen: Int
     ): Int
 
@@ -115,10 +116,10 @@ class LLamaAndroid {
         }
     }
 
-    fun send(message: String): Flow<String> = flow {
+    fun send(message: String, formatChat: Boolean = false): Flow<String> = flow {
         when (val state = threadLocalState.get()) {
             is State.Loaded -> {
-                val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
+                val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen))
                 while (ncur.value <= nlen) {
                     val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
                     if (str == null) {