namespace = "com.example.llama"
compileSdk = 34
- ndkVersion = "26.1.10909125"
-
defaultConfig {
applicationId = "com.example.llama"
minSdk = 33
vectorDrawables {
useSupportLibrary = true
}
- ndk {
- // Add NDK properties if wanted, e.g.
- // abiFilters += listOf("arm64-v8a")
- }
- externalNativeBuild {
- cmake {
- arguments += "-DCMAKE_BUILD_TYPE=Release"
- cppFlags += listOf()
- arguments += listOf()
- }
- }
}
buildTypes {
composeOptions {
kotlinCompilerExtensionVersion = "1.5.1"
}
- packaging {
- resources {
- excludes += "/META-INF/{AL2.0,LGPL2.1}"
- }
- }
- externalNativeBuild {
- cmake {
- path = file("src/main/cpp/CMakeLists.txt")
- version = "3.22.1"
- }
- }
}
dependencies {
implementation("androidx.compose.ui:ui-graphics")
implementation("androidx.compose.ui:ui-tooling-preview")
implementation("androidx.compose.material3:material3")
+ implementation(project(":llama"))
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
+++ /dev/null
-
-# For more information about using CMake with Android Studio, read the
-# documentation: https://d.android.com/studio/projects/add-native-code.html.
-# For more examples on how to use CMake, see https://github.com/android/ndk-samples.
-
-# Sets the minimum CMake version required for this project.
-cmake_minimum_required(VERSION 3.22.1)
-
-# Declares the project name. The project name can be accessed via ${ PROJECT_NAME},
-# Since this is the top level CMakeLists.txt, the project name is also accessible
-# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level
-# build script scope).
-project("llama-android")
-
-## Fetch latest llama.cpp from GitHub
-#include(FetchContent)
-#FetchContent_Declare(
-# llama
-# GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
-# GIT_TAG master
-#)
-#
-## Also provides "common"
-#FetchContent_MakeAvailable(llama)
-
-# llama.cpp CI uses the code from the current branch
-# ref: https://github.com/ggerganov/llama.cpp/pull/7341#issuecomment-2117617700
-add_subdirectory(../../../../../../ build-llama)
-
-# Creates and names a library, sets it as either STATIC
-# or SHARED, and provides the relative paths to its source code.
-# You can define multiple libraries, and CMake builds them for you.
-# Gradle automatically packages shared libraries with your APK.
-#
-# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define
-# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME}
-# is preferred for the same purpose.
-#
-# In order to load a library into your app from Java/Kotlin, you must call
-# System.loadLibrary() and pass the name of the library defined here;
-# for GameActivity/NativeActivity derived applications, the same library name must be
-# used in the AndroidManifest.xml file.
-add_library(${CMAKE_PROJECT_NAME} SHARED
- # List C/C++ source files with relative paths to this CMakeLists.txt.
- llama-android.cpp)
-
-# Specifies libraries CMake should link to your target library. You
-# can link libraries from various origins, such as libraries defined in this
-# build script, prebuilt third-party libraries, or Android system libraries.
-target_link_libraries(${CMAKE_PROJECT_NAME}
- # List libraries link to the target library
- llama
- common
- android
- log)
+++ /dev/null
-#include <android/log.h>
-#include <jni.h>
-#include <iomanip>
-#include <math.h>
-#include <string>
-#include <unistd.h>
-#include "llama.h"
-#include "common/common.h"
-
-// Write C++ code here.
-//
-// Do not forget to dynamically load the C++ library into your application.
-//
-// For instance,
-//
-// In MainActivity.java:
-// static {
-// System.loadLibrary("llama-android");
-// }
-//
-// Or, in MainActivity.kt:
-// companion object {
-// init {
-// System.loadLibrary("llama-android")
-// }
-// }
-
-#define TAG "llama-android.cpp"
-#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
-#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
-
-jclass la_int_var;
-jmethodID la_int_var_value;
-jmethodID la_int_var_inc;
-
-std::string cached_token_chars;
-
-bool is_valid_utf8(const char * string) {
- if (!string) {
- return true;
- }
-
- const unsigned char * bytes = (const unsigned char *)string;
- int num;
-
- while (*bytes != 0x00) {
- if ((*bytes & 0x80) == 0x00) {
- // U+0000 to U+007F
- num = 1;
- } else if ((*bytes & 0xE0) == 0xC0) {
- // U+0080 to U+07FF
- num = 2;
- } else if ((*bytes & 0xF0) == 0xE0) {
- // U+0800 to U+FFFF
- num = 3;
- } else if ((*bytes & 0xF8) == 0xF0) {
- // U+10000 to U+10FFFF
- num = 4;
- } else {
- return false;
- }
-
- bytes += 1;
- for (int i = 1; i < num; ++i) {
- if ((*bytes & 0xC0) != 0x80) {
- return false;
- }
- bytes += 1;
- }
- }
-
- return true;
-}
-
-static void log_callback(ggml_log_level level, const char * fmt, void * data) {
- if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
- else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
- else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
- else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_com_example_llama_Llm_load_1model(JNIEnv *env, jobject, jstring filename) {
- llama_model_params model_params = llama_model_default_params();
-
- auto path_to_model = env->GetStringUTFChars(filename, 0);
- LOGi("Loading model from %s", path_to_model);
-
- auto model = llama_load_model_from_file(path_to_model, model_params);
- env->ReleaseStringUTFChars(filename, path_to_model);
-
- if (!model) {
- LOGe("load_model() failed");
- env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
- return 0;
- }
-
- return reinterpret_cast<jlong>(model);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_example_llama_Llm_free_1model(JNIEnv *, jobject, jlong model) {
- llama_free_model(reinterpret_cast<llama_model *>(model));
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_com_example_llama_Llm_new_1context(JNIEnv *env, jobject, jlong jmodel) {
- auto model = reinterpret_cast<llama_model *>(jmodel);
-
- if (!model) {
- LOGe("new_context(): model cannot be null");
- env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
- return 0;
- }
-
- int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
- LOGi("Using %d threads", n_threads);
-
- llama_context_params ctx_params = llama_context_default_params();
- ctx_params.seed = 1234;
- ctx_params.n_ctx = 2048;
- ctx_params.n_threads = n_threads;
- ctx_params.n_threads_batch = n_threads;
-
- llama_context * context = llama_new_context_with_model(model, ctx_params);
-
- if (!context) {
- LOGe("llama_new_context_with_model() returned null)");
- env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
- "llama_new_context_with_model() returned null)");
- return 0;
- }
-
- return reinterpret_cast<jlong>(context);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_example_llama_Llm_free_1context(JNIEnv *, jobject, jlong context) {
- llama_free(reinterpret_cast<llama_context *>(context));
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_example_llama_Llm_backend_1free(JNIEnv *, jobject) {
- llama_backend_free();
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_example_llama_Llm_log_1to_1android(JNIEnv *, jobject) {
- llama_log_set(log_callback, NULL);
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_com_example_llama_Llm_bench_1model(
- JNIEnv *env,
- jobject,
- jlong context_pointer,
- jlong model_pointer,
- jlong batch_pointer,
- jint pp,
- jint tg,
- jint pl,
- jint nr
- ) {
- auto pp_avg = 0.0;
- auto tg_avg = 0.0;
- auto pp_std = 0.0;
- auto tg_std = 0.0;
-
- const auto context = reinterpret_cast<llama_context *>(context_pointer);
- const auto model = reinterpret_cast<llama_model *>(model_pointer);
- const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-
- const int n_ctx = llama_n_ctx(context);
-
- LOGi("n_ctx = %d", n_ctx);
-
- int i, j;
- int nri;
- for (nri = 0; nri < nr; nri++) {
- LOGi("Benchmark prompt processing (pp)");
-
- llama_batch_clear(*batch);
-
- const int n_tokens = pp;
- for (i = 0; i < n_tokens; i++) {
- llama_batch_add(*batch, 0, i, { 0 }, false);
- }
-
- batch->logits[batch->n_tokens - 1] = true;
- llama_kv_cache_clear(context);
-
- const auto t_pp_start = ggml_time_us();
- if (llama_decode(context, *batch) != 0) {
- LOGi("llama_decode() failed during prompt processing");
- }
- const auto t_pp_end = ggml_time_us();
-
- // bench text generation
-
- LOGi("Benchmark text generation (tg)");
-
- llama_kv_cache_clear(context);
- const auto t_tg_start = ggml_time_us();
- for (i = 0; i < tg; i++) {
-
- llama_batch_clear(*batch);
- for (j = 0; j < pl; j++) {
- llama_batch_add(*batch, 0, i, { j }, true);
- }
-
- LOGi("llama_decode() text generation: %d", i);
- if (llama_decode(context, *batch) != 0) {
- LOGi("llama_decode() failed during text generation");
- }
- }
-
- const auto t_tg_end = ggml_time_us();
-
- llama_kv_cache_clear(context);
-
- const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
- const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
-
- const auto speed_pp = double(pp) / t_pp;
- const auto speed_tg = double(pl * tg) / t_tg;
-
- pp_avg += speed_pp;
- tg_avg += speed_tg;
-
- pp_std += speed_pp * speed_pp;
- tg_std += speed_tg * speed_tg;
-
- LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
- }
-
- pp_avg /= double(nr);
- tg_avg /= double(nr);
-
- if (nr > 1) {
- pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
- tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
- } else {
- pp_std = 0;
- tg_std = 0;
- }
-
- char model_desc[128];
- llama_model_desc(model, model_desc, sizeof(model_desc));
-
- const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
- const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
-
- const auto backend = "(Android)"; // TODO: What should this be?
-
- std::stringstream result;
- result << std::setprecision(2);
- result << "| model | size | params | backend | test | t/s |\n";
- result << "| --- | --- | --- | --- | --- | --- |\n";
- result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
- result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
-
- return env->NewStringUTF(result.str().c_str());
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_example_llama_Llm_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
- llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_com_example_llama_Llm_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
-
- // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
-
- llama_batch *batch = new llama_batch {
- 0,
- nullptr,
- nullptr,
- nullptr,
- nullptr,
- nullptr,
- nullptr,
- 0,
- 0,
- 0,
- };
-
- if (embd) {
- batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
- } else {
- batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
- }
-
- batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
- batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
- batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
- for (int i = 0; i < n_tokens; ++i) {
- batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
- }
- batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
-
- return reinterpret_cast<jlong>(batch);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_example_llama_Llm_backend_1init(JNIEnv *, jobject) {
- llama_backend_init();
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_com_example_llama_Llm_system_1info(JNIEnv *env, jobject) {
- return env->NewStringUTF(llama_print_system_info());
-}
-
-extern "C"
-JNIEXPORT jint JNICALL
-Java_com_example_llama_Llm_completion_1init(
- JNIEnv *env,
- jobject,
- jlong context_pointer,
- jlong batch_pointer,
- jstring jtext,
- jint n_len
- ) {
-
- cached_token_chars.clear();
-
- const auto text = env->GetStringUTFChars(jtext, 0);
- const auto context = reinterpret_cast<llama_context *>(context_pointer);
- const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-
- const auto tokens_list = llama_tokenize(context, text, 1);
-
- auto n_ctx = llama_n_ctx(context);
- auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
-
- LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
-
- if (n_kv_req > n_ctx) {
- LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
- }
-
- for (auto id : tokens_list) {
- LOGi("%s", llama_token_to_piece(context, id).c_str());
- }
-
- llama_batch_clear(*batch);
-
- // evaluate the initial prompt
- for (auto i = 0; i < tokens_list.size(); i++) {
- llama_batch_add(*batch, tokens_list[i], i, { 0 }, false);
- }
-
- // llama_decode will output logits only for the last token of the prompt
- batch->logits[batch->n_tokens - 1] = true;
-
- if (llama_decode(context, *batch) != 0) {
- LOGe("llama_decode() failed");
- }
-
- env->ReleaseStringUTFChars(jtext, text);
-
- return batch->n_tokens;
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_com_example_llama_Llm_completion_1loop(
- JNIEnv * env,
- jobject,
- jlong context_pointer,
- jlong batch_pointer,
- jint n_len,
- jobject intvar_ncur
-) {
- const auto context = reinterpret_cast<llama_context *>(context_pointer);
- const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
- const auto model = llama_get_model(context);
-
- if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
- if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
- if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
-
- auto n_vocab = llama_n_vocab(model);
- auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
-
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
-
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
- }
-
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
- // sample the most likely token
- const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
-
- const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
- if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
- return env->NewStringUTF("");
- }
-
- auto new_token_chars = llama_token_to_piece(context, new_token_id);
- cached_token_chars += new_token_chars;
-
- jstring new_token = nullptr;
- if (is_valid_utf8(cached_token_chars.c_str())) {
- new_token = env->NewStringUTF(cached_token_chars.c_str());
- LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
- cached_token_chars.clear();
- } else {
- new_token = env->NewStringUTF("");
- }
-
- llama_batch_clear(*batch);
- llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
-
- env->CallVoidMethod(intvar_ncur, la_int_var_inc);
-
- if (llama_decode(context, *batch) != 0) {
- LOGe("llama_decode() returned null");
- }
-
- return new_token;
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_example_llama_Llm_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
- llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
-}
+++ /dev/null
-package com.example.llama
-
-import android.util.Log
-import kotlinx.coroutines.CoroutineDispatcher
-import kotlinx.coroutines.asCoroutineDispatcher
-import kotlinx.coroutines.flow.Flow
-import kotlinx.coroutines.flow.flow
-import kotlinx.coroutines.flow.flowOn
-import kotlinx.coroutines.withContext
-import java.util.concurrent.Executors
-import kotlin.concurrent.thread
-
-class Llm {
- private val tag: String? = this::class.simpleName
-
- private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
-
- private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
- thread(start = false, name = "Llm-RunLoop") {
- Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
-
- // No-op if called more than once.
- System.loadLibrary("llama-android")
-
- // Set llama log handler to Android
- log_to_android()
- backend_init(false)
-
- Log.d(tag, system_info())
-
- it.run()
- }.apply {
- uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
- Log.e(tag, "Unhandled exception", exception)
- }
- }
- }.asCoroutineDispatcher()
-
- private val nlen: Int = 64
-
- private external fun log_to_android()
- private external fun load_model(filename: String): Long
- private external fun free_model(model: Long)
- private external fun new_context(model: Long): Long
- private external fun free_context(context: Long)
- private external fun backend_init(numa: Boolean)
- private external fun backend_free()
- private external fun free_batch(batch: Long)
- private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
- private external fun bench_model(
- context: Long,
- model: Long,
- batch: Long,
- pp: Int,
- tg: Int,
- pl: Int,
- nr: Int
- ): String
-
- private external fun system_info(): String
-
- private external fun completion_init(
- context: Long,
- batch: Long,
- text: String,
- nLen: Int
- ): Int
-
- private external fun completion_loop(
- context: Long,
- batch: Long,
- nLen: Int,
- ncur: IntVar
- ): String?
-
- private external fun kv_cache_clear(context: Long)
-
- suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
- return withContext(runLoop) {
- when (val state = threadLocalState.get()) {
- is State.Loaded -> {
- Log.d(tag, "bench(): $state")
- bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
- }
-
- else -> throw IllegalStateException("No model loaded")
- }
- }
- }
-
- suspend fun load(pathToModel: String) {
- withContext(runLoop) {
- when (threadLocalState.get()) {
- is State.Idle -> {
- val model = load_model(pathToModel)
- if (model == 0L) throw IllegalStateException("load_model() failed")
-
- val context = new_context(model)
- if (context == 0L) throw IllegalStateException("new_context() failed")
-
- val batch = new_batch(512, 0, 1)
- if (batch == 0L) throw IllegalStateException("new_batch() failed")
-
- Log.i(tag, "Loaded model $pathToModel")
- threadLocalState.set(State.Loaded(model, context, batch))
- }
- else -> throw IllegalStateException("Model already loaded")
- }
- }
- }
-
- fun send(message: String): Flow<String> = flow {
- when (val state = threadLocalState.get()) {
- is State.Loaded -> {
- val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
- while (ncur.value <= nlen) {
- val str = completion_loop(state.context, state.batch, nlen, ncur)
- if (str == null) {
- break
- }
- emit(str)
- }
- kv_cache_clear(state.context)
- }
- else -> {}
- }
- }.flowOn(runLoop)
-
- /**
- * Unloads the model and frees resources.
- *
- * This is a no-op if there's no model loaded.
- */
- suspend fun unload() {
- withContext(runLoop) {
- when (val state = threadLocalState.get()) {
- is State.Loaded -> {
- free_context(state.context)
- free_model(state.model)
- free_batch(state.batch)
-
- threadLocalState.set(State.Idle)
- }
- else -> {}
- }
- }
- }
-
- companion object {
- private class IntVar(value: Int) {
- @Volatile
- var value: Int = value
- private set
-
- fun inc() {
- synchronized(this) {
- value += 1
- }
- }
- }
-
- private sealed interface State {
- data object Idle: State
- data class Loaded(val model: Long, val context: Long, val batch: Long): State
- }
-
- // Enforce only one instance of Llm.
- private val _instance: Llm = Llm()
-
- fun instance(): Llm = _instance
- }
-}
package com.example.llama
+import android.llama.cpp.LLamaAndroid
import android.util.Log
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.launch
-class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
+class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()): ViewModel() {
companion object {
@JvmStatic
private val NanosPerSecond = 1_000_000_000.0
viewModelScope.launch {
try {
- llm.unload()
+ llamaAndroid.unload()
} catch (exc: IllegalStateException) {
messages += exc.message!!
}
messages += ""
viewModelScope.launch {
- llm.send(text)
+ llamaAndroid.send(text)
.catch {
Log.e(tag, "send() failed", it)
messages += it.message!!
viewModelScope.launch {
try {
val start = System.nanoTime()
- val warmupResult = llm.bench(pp, tg, pl, nr)
+ val warmupResult = llamaAndroid.bench(pp, tg, pl, nr)
val end = System.nanoTime()
messages += warmupResult
return@launch
}
- messages += llm.bench(512, 128, 1, 3)
+ messages += llamaAndroid.bench(512, 128, 1, 3)
} catch (exc: IllegalStateException) {
Log.e(tag, "bench() failed", exc)
messages += exc.message!!
fun load(pathToModel: String) {
viewModelScope.launch {
try {
- llm.load(pathToModel)
+ llamaAndroid.load(pathToModel)
messages += "Loaded $pathToModel"
} catch (exc: IllegalStateException) {
Log.e(tag, "load() failed", exc)
plugins {
id("com.android.application") version "8.2.0" apply false
id("org.jetbrains.kotlin.android") version "1.9.0" apply false
+ id("com.android.library") version "8.2.0" apply false
}
--- /dev/null
+
+# For more information about using CMake with Android Studio, read the
+# documentation: https://d.android.com/studio/projects/add-native-code.html.
+# For more examples on how to use CMake, see https://github.com/android/ndk-samples.
+
+# Sets the minimum CMake version required for this project.
+cmake_minimum_required(VERSION 3.22.1)
+
+# Declares the project name. The project name can be accessed via ${ PROJECT_NAME},
+# Since this is the top level CMakeLists.txt, the project name is also accessible
+# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level
+# build script scope).
+project("llama-android")
+
+## Fetch latest llama.cpp from GitHub
+#include(FetchContent)
+#FetchContent_Declare(
+# llama
+# GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
+# GIT_TAG master
+#)
+#
+## Also provides "common"
+#FetchContent_MakeAvailable(llama)
+
+# llama.cpp CI uses the code from the current branch
+# ref: https://github.com/ggerganov/llama.cpp/pull/7341#issuecomment-2117617700
+add_subdirectory(../../../../../../ build-llama)
+
+# Creates and names a library, sets it as either STATIC
+# or SHARED, and provides the relative paths to its source code.
+# You can define multiple libraries, and CMake builds them for you.
+# Gradle automatically packages shared libraries with your APK.
+#
+# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define
+# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME}
+# is preferred for the same purpose.
+#
+# In order to load a library into your app from Java/Kotlin, you must call
+# System.loadLibrary() and pass the name of the library defined here;
+# for GameActivity/NativeActivity derived applications, the same library name must be
+# used in the AndroidManifest.xml file.
+add_library(${CMAKE_PROJECT_NAME} SHARED
+ # List C/C++ source files with relative paths to this CMakeLists.txt.
+ llama-android.cpp)
+
+# Specifies libraries CMake should link to your target library. You
+# can link libraries from various origins, such as libraries defined in this
+# build script, prebuilt third-party libraries, or Android system libraries.
+target_link_libraries(${CMAKE_PROJECT_NAME}
+ # List libraries link to the target library
+ llama
+ common
+ android
+ log)
--- /dev/null
+plugins {
+ id("com.android.library")
+ id("org.jetbrains.kotlin.android")
+}
+
+android {
+ namespace = "android.llama.cpp"
+ compileSdk = 34
+
+ defaultConfig {
+ minSdk = 33
+
+ testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
+ consumerProguardFiles("consumer-rules.pro")
+ ndk {
+ // Add NDK properties if wanted, e.g.
+ // abiFilters += listOf("arm64-v8a")
+ }
+ externalNativeBuild {
+ cmake {
+ arguments += "-DCMAKE_BUILD_TYPE=Release"
+ cppFlags += listOf()
+ arguments += listOf()
+
+ cppFlags("")
+ }
+ }
+ }
+
+ buildTypes {
+ release {
+ isMinifyEnabled = false
+ proguardFiles(
+ getDefaultProguardFile("proguard-android-optimize.txt"),
+ "proguard-rules.pro"
+ )
+ }
+ }
+ externalNativeBuild {
+ cmake {
+ path("src/main/cpp/CMakeLists.txt")
+ version = "3.22.1"
+ }
+ }
+ compileOptions {
+ sourceCompatibility = JavaVersion.VERSION_1_8
+ targetCompatibility = JavaVersion.VERSION_1_8
+ }
+ kotlinOptions {
+ jvmTarget = "1.8"
+ }
+
+ packaging {
+ resources {
+ excludes += "/META-INF/{AL2.0,LGPL2.1}"
+ }
+ }
+}
+
+dependencies {
+
+ implementation("androidx.core:core-ktx:1.12.0")
+ implementation("androidx.appcompat:appcompat:1.6.1")
+ implementation("com.google.android.material:material:1.11.0")
+ testImplementation("junit:junit:4.13.2")
+ androidTestImplementation("androidx.test.ext:junit:1.1.5")
+ androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
+}
--- /dev/null
+# Add project specific ProGuard rules here.
+# You can control the set of applied configuration files using the
+# proguardFiles setting in build.gradle.
+#
+# For more details, see
+# http://developer.android.com/guide/developing/tools/proguard.html
+
+# If your project uses WebView with JS, uncomment the following
+# and specify the fully qualified class name to the JavaScript interface
+# class:
+#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
+# public *;
+#}
+
+# Uncomment this to preserve the line number information for
+# debugging stack traces.
+#-keepattributes SourceFile,LineNumberTable
+
+# If you keep the line number information, uncomment this to
+# hide the original source file name.
+#-renamesourcefileattribute SourceFile
--- /dev/null
+package android.llama.cpp
+
+import androidx.test.platform.app.InstrumentationRegistry
+import androidx.test.ext.junit.runners.AndroidJUnit4
+
+import org.junit.Test
+import org.junit.runner.RunWith
+
+import org.junit.Assert.*
+
+/**
+ * Instrumented test, which will execute on an Android device.
+ *
+ * See [testing documentation](http://d.android.com/tools/testing).
+ */
+@RunWith(AndroidJUnit4::class)
+class ExampleInstrumentedTest {
+ @Test
+ fun useAppContext() {
+ // Context of the app under test.
+ val appContext = InstrumentationRegistry.getInstrumentation().targetContext
+ assertEquals("android.llama.cpp.test", appContext.packageName)
+ }
+}
--- /dev/null
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android">
+
+</manifest>
--- /dev/null
+# For more information about using CMake with Android Studio, read the
+# documentation: https://d.android.com/studio/projects/add-native-code.html.
+# For more examples on how to use CMake, see https://github.com/android/ndk-samples.
+
+# Sets the minimum CMake version required for this project.
+cmake_minimum_required(VERSION 3.22.1)
+
+# Declares the project name. The project name can be accessed via ${ PROJECT_NAME},
+# Since this is the top level CMakeLists.txt, the project name is also accessible
+# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level
+# build script scope).
+project("llama-android")
+
+include(FetchContent)
+FetchContent_Declare(
+ llama
+ GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
+ GIT_TAG master
+)
+
+# Also provides "common"
+FetchContent_MakeAvailable(llama)
+
+# Creates and names a library, sets it as either STATIC
+# or SHARED, and provides the relative paths to its source code.
+# You can define multiple libraries, and CMake builds them for you.
+# Gradle automatically packages shared libraries with your APK.
+#
+# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define
+# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME}
+# is preferred for the same purpose.
+#
+# In order to load a library into your app from Java/Kotlin, you must call
+# System.loadLibrary() and pass the name of the library defined here;
+# for GameActivity/NativeActivity derived applications, the same library name must be
+# used in the AndroidManifest.xml file.
+add_library(${CMAKE_PROJECT_NAME} SHARED
+ # List C/C++ source files with relative paths to this CMakeLists.txt.
+ llama-android.cpp)
+
+# Specifies libraries CMake should link to your target library. You
+# can link libraries from various origins, such as libraries defined in this
+# build script, prebuilt third-party libraries, or Android system libraries.
+target_link_libraries(${CMAKE_PROJECT_NAME}
+ # List libraries link to the target library
+ llama
+ common
+ android
+ log)
--- /dev/null
+#include <android/log.h>
+#include <jni.h>
+#include <iomanip>
+#include <math.h>
+#include <string>
+#include <unistd.h>
+#include "llama.h"
+#include "common/common.h"
+
+// Write C++ code here.
+//
+// Do not forget to dynamically load the C++ library into your application.
+//
+// For instance,
+//
+// In MainActivity.java:
+// static {
+// System.loadLibrary("llama-android");
+// }
+//
+// Or, in MainActivity.kt:
+// companion object {
+// init {
+// System.loadLibrary("llama-android")
+// }
+// }
+
+#define TAG "llama-android.cpp"
+#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
+#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
+
+jclass la_int_var;
+jmethodID la_int_var_value;
+jmethodID la_int_var_inc;
+
+std::string cached_token_chars;
+
+bool is_valid_utf8(const char * string) {
+ if (!string) {
+ return true;
+ }
+
+ const unsigned char * bytes = (const unsigned char *)string;
+ int num;
+
+ while (*bytes != 0x00) {
+ if ((*bytes & 0x80) == 0x00) {
+ // U+0000 to U+007F
+ num = 1;
+ } else if ((*bytes & 0xE0) == 0xC0) {
+ // U+0080 to U+07FF
+ num = 2;
+ } else if ((*bytes & 0xF0) == 0xE0) {
+ // U+0800 to U+FFFF
+ num = 3;
+ } else if ((*bytes & 0xF8) == 0xF0) {
+ // U+10000 to U+10FFFF
+ num = 4;
+ } else {
+ return false;
+ }
+
+ bytes += 1;
+ for (int i = 1; i < num; ++i) {
+ if ((*bytes & 0xC0) != 0x80) {
+ return false;
+ }
+ bytes += 1;
+ }
+ }
+
+ return true;
+}
+
+static void log_callback(ggml_log_level level, const char * fmt, void * data) {
+ if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
+ else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
+ else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
+ else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
+}
+
+extern "C"
+JNIEXPORT jlong JNICALL
+Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
+ llama_model_params model_params = llama_model_default_params();
+
+ auto path_to_model = env->GetStringUTFChars(filename, 0);
+ LOGi("Loading model from %s", path_to_model);
+
+ auto model = llama_load_model_from_file(path_to_model, model_params);
+ env->ReleaseStringUTFChars(filename, path_to_model);
+
+ if (!model) {
+ LOGe("load_model() failed");
+ env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
+ return 0;
+ }
+
+ return reinterpret_cast<jlong>(model);
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
+ llama_free_model(reinterpret_cast<llama_model *>(model));
+}
+
+extern "C"
+JNIEXPORT jlong JNICALL
+Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) {
+ auto model = reinterpret_cast<llama_model *>(jmodel);
+
+ if (!model) {
+ LOGe("new_context(): model cannot be null");
+ env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
+ return 0;
+ }
+
+ int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
+ LOGi("Using %d threads", n_threads);
+
+ llama_context_params ctx_params = llama_context_default_params();
+ ctx_params.seed = 1234;
+ ctx_params.n_ctx = 2048;
+ ctx_params.n_threads = n_threads;
+ ctx_params.n_threads_batch = n_threads;
+
+ llama_context * context = llama_new_context_with_model(model, ctx_params);
+
+ if (!context) {
+ LOGe("llama_new_context_with_model() returned null)");
+ env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
+ "llama_new_context_with_model() returned null)");
+ return 0;
+ }
+
+ return reinterpret_cast<jlong>(context);
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) {
+ llama_free(reinterpret_cast<llama_context *>(context));
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) {
+ llama_backend_free();
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
+ llama_log_set(log_callback, NULL);
+}
+
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_android_llama_cpp_LLamaAndroid_bench_1model(
+ JNIEnv *env,
+ jobject,
+ jlong context_pointer,
+ jlong model_pointer,
+ jlong batch_pointer,
+ jint pp,
+ jint tg,
+ jint pl,
+ jint nr
+ ) {
+ auto pp_avg = 0.0;
+ auto tg_avg = 0.0;
+ auto pp_std = 0.0;
+ auto tg_std = 0.0;
+
+ const auto context = reinterpret_cast<llama_context *>(context_pointer);
+ const auto model = reinterpret_cast<llama_model *>(model_pointer);
+ const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+
+ const int n_ctx = llama_n_ctx(context);
+
+ LOGi("n_ctx = %d", n_ctx);
+
+ int i, j;
+ int nri;
+ for (nri = 0; nri < nr; nri++) {
+ LOGi("Benchmark prompt processing (pp)");
+
+ llama_batch_clear(*batch);
+
+ const int n_tokens = pp;
+ for (i = 0; i < n_tokens; i++) {
+ llama_batch_add(*batch, 0, i, { 0 }, false);
+ }
+
+ batch->logits[batch->n_tokens - 1] = true;
+ llama_kv_cache_clear(context);
+
+ const auto t_pp_start = ggml_time_us();
+ if (llama_decode(context, *batch) != 0) {
+ LOGi("llama_decode() failed during prompt processing");
+ }
+ const auto t_pp_end = ggml_time_us();
+
+ // bench text generation
+
+ LOGi("Benchmark text generation (tg)");
+
+ llama_kv_cache_clear(context);
+ const auto t_tg_start = ggml_time_us();
+ for (i = 0; i < tg; i++) {
+
+ llama_batch_clear(*batch);
+ for (j = 0; j < pl; j++) {
+ llama_batch_add(*batch, 0, i, { j }, true);
+ }
+
+ LOGi("llama_decode() text generation: %d", i);
+ if (llama_decode(context, *batch) != 0) {
+ LOGi("llama_decode() failed during text generation");
+ }
+ }
+
+ const auto t_tg_end = ggml_time_us();
+
+ llama_kv_cache_clear(context);
+
+ const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
+ const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
+
+ const auto speed_pp = double(pp) / t_pp;
+ const auto speed_tg = double(pl * tg) / t_tg;
+
+ pp_avg += speed_pp;
+ tg_avg += speed_tg;
+
+ pp_std += speed_pp * speed_pp;
+ tg_std += speed_tg * speed_tg;
+
+ LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
+ }
+
+ pp_avg /= double(nr);
+ tg_avg /= double(nr);
+
+ if (nr > 1) {
+ pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
+ tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
+ } else {
+ pp_std = 0;
+ tg_std = 0;
+ }
+
+ char model_desc[128];
+ llama_model_desc(model, model_desc, sizeof(model_desc));
+
+ const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
+ const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
+
+ const auto backend = "(Android)"; // TODO: What should this be?
+
+ std::stringstream result;
+ result << std::setprecision(2);
+ result << "| model | size | params | backend | test | t/s |\n";
+ result << "| --- | --- | --- | --- | --- | --- |\n";
+ result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
+ result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
+
+ return env->NewStringUTF(result.str().c_str());
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
+ llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
+}
+
+extern "C"
+JNIEXPORT jlong JNICALL
+Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
+
+ // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
+
+ llama_batch *batch = new llama_batch {
+ 0,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ 0,
+ 0,
+ 0,
+ };
+
+ if (embd) {
+ batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
+ } else {
+ batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
+ }
+
+ batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
+ batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
+ batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
+ for (int i = 0; i < n_tokens; ++i) {
+ batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
+ }
+ batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+
+ return reinterpret_cast<jlong>(batch);
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
+ llama_backend_init();
+}
+
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) {
+ return env->NewStringUTF(llama_print_system_info());
+}
+
+extern "C"
+JNIEXPORT jint JNICALL
+Java_android_llama_cpp_LLamaAndroid_completion_1init(
+ JNIEnv *env,
+ jobject,
+ jlong context_pointer,
+ jlong batch_pointer,
+ jstring jtext,
+ jint n_len
+ ) {
+
+ cached_token_chars.clear();
+
+ const auto text = env->GetStringUTFChars(jtext, 0);
+ const auto context = reinterpret_cast<llama_context *>(context_pointer);
+ const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+
+ const auto tokens_list = llama_tokenize(context, text, 1);
+
+ auto n_ctx = llama_n_ctx(context);
+ auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
+
+ LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
+
+ if (n_kv_req > n_ctx) {
+ LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
+ }
+
+ for (auto id : tokens_list) {
+ LOGi("%s", llama_token_to_piece(context, id).c_str());
+ }
+
+ llama_batch_clear(*batch);
+
+ // evaluate the initial prompt
+ for (auto i = 0; i < tokens_list.size(); i++) {
+ llama_batch_add(*batch, tokens_list[i], i, { 0 }, false);
+ }
+
+ // llama_decode will output logits only for the last token of the prompt
+ batch->logits[batch->n_tokens - 1] = true;
+
+ if (llama_decode(context, *batch) != 0) {
+ LOGe("llama_decode() failed");
+ }
+
+ env->ReleaseStringUTFChars(jtext, text);
+
+ return batch->n_tokens;
+}
+
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_android_llama_cpp_LLamaAndroid_completion_1loop(
+ JNIEnv * env,
+ jobject,
+ jlong context_pointer,
+ jlong batch_pointer,
+ jint n_len,
+ jobject intvar_ncur
+) {
+ const auto context = reinterpret_cast<llama_context *>(context_pointer);
+ const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+ const auto model = llama_get_model(context);
+
+ if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
+ if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
+ if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
+
+ auto n_vocab = llama_n_vocab(model);
+ auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
+
+ std::vector<llama_token_data> candidates;
+ candidates.reserve(n_vocab);
+
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+ }
+
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+ // sample the most likely token
+ const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
+
+ const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
+ if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
+ return env->NewStringUTF("");
+ }
+
+ auto new_token_chars = llama_token_to_piece(context, new_token_id);
+ cached_token_chars += new_token_chars;
+
+ jstring new_token = nullptr;
+ if (is_valid_utf8(cached_token_chars.c_str())) {
+ new_token = env->NewStringUTF(cached_token_chars.c_str());
+ LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
+ cached_token_chars.clear();
+ } else {
+ new_token = env->NewStringUTF("");
+ }
+
+ llama_batch_clear(*batch);
+ llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
+
+ env->CallVoidMethod(intvar_ncur, la_int_var_inc);
+
+ if (llama_decode(context, *batch) != 0) {
+ LOGe("llama_decode() returned null");
+ }
+
+ return new_token;
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
+ llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
+}
--- /dev/null
+package android.llama.cpp
+
+import android.util.Log
+import kotlinx.coroutines.CoroutineDispatcher
+import kotlinx.coroutines.asCoroutineDispatcher
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.withContext
+import java.util.concurrent.Executors
+import kotlin.concurrent.thread
+
+class LLamaAndroid {
+ private val tag: String? = this::class.simpleName
+
+ private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
+
+ private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
+ thread(start = false, name = "Llm-RunLoop") {
+ Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
+
+ // No-op if called more than once.
+ System.loadLibrary("llama-android")
+
+ // Set llama log handler to Android
+ log_to_android()
+ backend_init(false)
+
+ Log.d(tag, system_info())
+
+ it.run()
+ }.apply {
+ uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
+ Log.e(tag, "Unhandled exception", exception)
+ }
+ }
+ }.asCoroutineDispatcher()
+
+ private val nlen: Int = 64
+
+ private external fun log_to_android()
+ private external fun load_model(filename: String): Long
+ private external fun free_model(model: Long)
+ private external fun new_context(model: Long): Long
+ private external fun free_context(context: Long)
+ private external fun backend_init(numa: Boolean)
+ private external fun backend_free()
+ private external fun free_batch(batch: Long)
+ private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
+ private external fun bench_model(
+ context: Long,
+ model: Long,
+ batch: Long,
+ pp: Int,
+ tg: Int,
+ pl: Int,
+ nr: Int
+ ): String
+
+ private external fun system_info(): String
+
+ private external fun completion_init(
+ context: Long,
+ batch: Long,
+ text: String,
+ nLen: Int
+ ): Int
+
+ private external fun completion_loop(
+ context: Long,
+ batch: Long,
+ nLen: Int,
+ ncur: IntVar
+ ): String?
+
+ private external fun kv_cache_clear(context: Long)
+
+ suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
+ return withContext(runLoop) {
+ when (val state = threadLocalState.get()) {
+ is State.Loaded -> {
+ Log.d(tag, "bench(): $state")
+ bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
+ }
+
+ else -> throw IllegalStateException("No model loaded")
+ }
+ }
+ }
+
+ suspend fun load(pathToModel: String) {
+ withContext(runLoop) {
+ when (threadLocalState.get()) {
+ is State.Idle -> {
+ val model = load_model(pathToModel)
+ if (model == 0L) throw IllegalStateException("load_model() failed")
+
+ val context = new_context(model)
+ if (context == 0L) throw IllegalStateException("new_context() failed")
+
+ val batch = new_batch(512, 0, 1)
+ if (batch == 0L) throw IllegalStateException("new_batch() failed")
+
+ Log.i(tag, "Loaded model $pathToModel")
+ threadLocalState.set(State.Loaded(model, context, batch))
+ }
+ else -> throw IllegalStateException("Model already loaded")
+ }
+ }
+ }
+
+ fun send(message: String): Flow<String> = flow {
+ when (val state = threadLocalState.get()) {
+ is State.Loaded -> {
+ val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
+ while (ncur.value <= nlen) {
+ val str = completion_loop(state.context, state.batch, nlen, ncur)
+ if (str == null) {
+ break
+ }
+ emit(str)
+ }
+ kv_cache_clear(state.context)
+ }
+ else -> {}
+ }
+ }.flowOn(runLoop)
+
+ /**
+ * Unloads the model and frees resources.
+ *
+ * This is a no-op if there's no model loaded.
+ */
+ suspend fun unload() {
+ withContext(runLoop) {
+ when (val state = threadLocalState.get()) {
+ is State.Loaded -> {
+ free_context(state.context)
+ free_model(state.model)
+ free_batch(state.batch)
+
+ threadLocalState.set(State.Idle)
+ }
+ else -> {}
+ }
+ }
+ }
+
+ companion object {
+ private class IntVar(value: Int) {
+ @Volatile
+ var value: Int = value
+ private set
+
+ fun inc() {
+ synchronized(this) {
+ value += 1
+ }
+ }
+ }
+
+ private sealed interface State {
+ data object Idle: State
+ data class Loaded(val model: Long, val context: Long, val batch: Long): State
+ }
+
+ // Enforce only one instance of Llm.
+ private val _instance: LLamaAndroid = LLamaAndroid()
+
+ fun instance(): LLamaAndroid = _instance
+ }
+}
--- /dev/null
+package android.llama.cpp
+
+import org.junit.Test
+
+import org.junit.Assert.*
+
+/**
+ * Example local unit test, which will execute on the development machine (host).
+ *
+ * See [testing documentation](http://d.android.com/tools/testing).
+ */
+class ExampleUnitTest {
+ @Test
+ fun addition_isCorrect() {
+ assertEquals(4, 2 + 2)
+ }
+}
rootProject.name = "LlamaAndroid"
include(":app")
+include(":llama")