]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama.android : fix build (#9350)
authorGeorgi Gerganov <redacted>
Sat, 7 Sep 2024 21:33:50 +0000 (00:33 +0300)
committerGitHub <redacted>
Sat, 7 Sep 2024 21:33:50 +0000 (00:33 +0300)
examples/llama.android/llama/src/main/cpp/llama-android.cpp
examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt

index 9217937512d75f731057e966ce2ffb1cbad4e4ae..06ec160c2994042454eb0552e83c8767b13fe8f0 100644 (file)
@@ -269,12 +269,6 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
     return env->NewStringUTF(result.str().c_str());
 }
 
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
-    llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
-}
-
 extern "C"
 JNIEXPORT jlong JNICALL
 Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
@@ -311,6 +305,29 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
     return reinterpret_cast<jlong>(batch);
 }
 
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
+    llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
+}
+
+extern "C"
+JNIEXPORT jlong JNICALL
+Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
+    auto sparams = llama_sampler_chain_default_params();
+    sparams.no_perf = true;
+    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
+
+    return reinterpret_cast<jlong>(smpl);
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) {
+    llama_sampler_free(reinterpret_cast<llama_sampler *>(sampler_pointer));
+}
+
 extern "C"
 JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
@@ -380,14 +397,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
         JNIEnv * env,
         jobject,
         jlong context_pointer,
-        jlong sampling_pointer,
         jlong batch_pointer,
+        jlong sampler_pointer,
         jint n_len,
         jobject intvar_ncur
 ) {
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto sampling = reinterpret_cast<llama_sampler *>(sampling_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+    const auto batch   = reinterpret_cast<llama_batch   *>(batch_pointer);
+    const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
     const auto model = llama_get_model(context);
 
     if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
@@ -395,9 +412,9 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
     if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
 
     // sample the most likely token
-    const auto new_token_id = llama_sampler_sample(sampling, context, batch->n_tokens - 1);
+    const auto new_token_id = llama_sampler_sample(sampler, context, -1);
 
-    llama_sampler_accept(sampling, new_token_id);
+    llama_sampler_accept(sampler, new_token_id);
 
     const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
     if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
index 6c63e54e0d9082d8e858a2e8b35f5464aae47998..cf520e4594004194f66d3ee27643c2a807a825a1 100644 (file)
@@ -45,8 +45,10 @@ class LLamaAndroid {
     private external fun free_context(context: Long)
     private external fun backend_init(numa: Boolean)
     private external fun backend_free()
-    private external fun free_batch(batch: Long)
     private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
+    private external fun free_batch(batch: Long)
+    private external fun new_sampler(): Long
+    private external fun free_sampler(sampler: Long)
     private external fun bench_model(
         context: Long,
         model: Long,
@@ -69,6 +71,7 @@ class LLamaAndroid {
     private external fun completion_loop(
         context: Long,
         batch: Long,
+        sampler: Long,
         nLen: Int,
         ncur: IntVar
     ): String?
@@ -101,8 +104,11 @@ class LLamaAndroid {
                     val batch = new_batch(512, 0, 1)
                     if (batch == 0L) throw IllegalStateException("new_batch() failed")
 
+                    val sampler = new_sampler()
+                    if (sampler == 0L) throw IllegalStateException("new_sampler() failed")
+
                     Log.i(tag, "Loaded model $pathToModel")
-                    threadLocalState.set(State.Loaded(model, context, batch))
+                    threadLocalState.set(State.Loaded(model, context, batch, sampler))
                 }
                 else -> throw IllegalStateException("Model already loaded")
             }
@@ -114,7 +120,7 @@ class LLamaAndroid {
             is State.Loaded -> {
                 val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
                 while (ncur.value <= nlen) {
-                    val str = completion_loop(state.context, state.batch, nlen, ncur)
+                    val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
                     if (str == null) {
                         break
                     }
@@ -138,6 +144,7 @@ class LLamaAndroid {
                     free_context(state.context)
                     free_model(state.model)
                     free_batch(state.batch)
+                    free_sampler(state.sampler);
 
                     threadLocalState.set(State.Idle)
                 }
@@ -161,7 +168,7 @@ class LLamaAndroid {
 
         private sealed interface State {
             data object Idle: State
-            data class Loaded(val model: Long, val context: Long, val batch: Long): State
+            data class Loaded(val model: Long, val context: Long, val batch: Long, val sampler: Long): State
         }
 
         // Enforce only one instance of Llm.