From: Brian
Date: Tue, 14 May 2024 13:10:39 +0000 (+1000)
Subject: Revert "move ndk code to a new library (#6951)" (#7282)

Revert "move ndk code to a new library (#6951)" (#7282)

This reverts commit efc8f767c8c8c749a245dd96ad4e2f37c164b54c.
---

diff --git a/examples/llama.android/app/build.gradle.kts b/examples/llama.android/app/build.gradle.kts
index 8d1b3719..d42140ef 100644
--- a/examples/llama.android/app/build.gradle.kts
+++ b/examples/llama.android/app/build.gradle.kts
@@ -7,6 +7,8 @@ android {
     namespace = "com.example.llama"
     compileSdk = 34
 
+    ndkVersion = "26.1.10909125"
+
     defaultConfig {
         applicationId = "com.example.llama"
         minSdk = 33
@@ -18,6 +20,17 @@ android {
         vectorDrawables {
             useSupportLibrary = true
         }
+        ndk {
+            // Add NDK properties if wanted, e.g.
+            // abiFilters += listOf("arm64-v8a")
+        }
+        externalNativeBuild {
+            cmake {
+                arguments += "-DCMAKE_BUILD_TYPE=Release"
+                cppFlags += listOf()
+                arguments += listOf()
+            }
+        }
     }
 
     buildTypes {
@@ -42,6 +55,17 @@ android {
     composeOptions {
         kotlinCompilerExtensionVersion = "1.5.1"
     }
+    packaging {
+        resources {
+            excludes += "/META-INF/{AL2.0,LGPL2.1}"
+        }
+    }
+    externalNativeBuild {
+        cmake {
+            path = file("src/main/cpp/CMakeLists.txt")
+            version = "3.22.1"
+        }
+    }
 }
 
 dependencies {
@@ -54,7 +78,6 @@ dependencies {
     implementation("androidx.compose.ui:ui-graphics")
     implementation("androidx.compose.ui:ui-tooling-preview")
     implementation("androidx.compose.material3:material3")
-    implementation(project(":llama"))
     testImplementation("junit:junit:4.13.2")
     androidTestImplementation("androidx.test.ext:junit:1.1.5")
    androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
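The ndk and externalNativeBuild blocks added above are left essentially empty by this commit. For illustration only, a filled-in version might look like the following sketch -- the abiFilters value comes from the commented example in the diff itself, while the extra cppFlags entry is purely hypothetical:

    android {
        defaultConfig {
            ndk {
                // Package only 64-bit ARM, per the commented example above.
                abiFilters += listOf("arm64-v8a")
            }
            externalNativeBuild {
                cmake {
                    arguments += "-DCMAKE_BUILD_TYPE=Release"
                    cppFlags += listOf("-std=c++17")  // hypothetical extra flag
                }
            }
        }
    }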
diff --git a/examples/llama.android/app/src/main/cpp/CMakeLists.txt b/examples/llama.android/app/src/main/cpp/CMakeLists.txt
new file mode 100644
index 00000000..85139329
--- /dev/null
+++ b/examples/llama.android/app/src/main/cpp/CMakeLists.txt
@@ -0,0 +1,50 @@
+
+# For more information about using CMake with Android Studio, read the
+# documentation: https://d.android.com/studio/projects/add-native-code.html.
+# For more examples on how to use CMake, see https://github.com/android/ndk-samples.
+
+# Sets the minimum CMake version required for this project.
+cmake_minimum_required(VERSION 3.22.1)
+
+# Declares the project name. The project name can be accessed via ${PROJECT_NAME}.
+# Since this is the top-level CMakeLists.txt, the project name is also accessible
+# with ${CMAKE_PROJECT_NAME} (both CMake variables are in sync within the top-level
+# build script scope).
+project("llama-android")
+
+include(FetchContent)
+FetchContent_Declare(
+        llama
+        GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
+        GIT_TAG master
+)
+
+# Also provides "common"
+FetchContent_MakeAvailable(llama)
+
+# Creates and names a library, sets it as either STATIC
+# or SHARED, and provides the relative paths to its source code.
+# You can define multiple libraries, and CMake builds them for you.
+# Gradle automatically packages shared libraries with your APK.
+#
+# In this top-level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define
+# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME}
+# is preferred for the same purpose.
+#
+# In order to load a library into your app from Java/Kotlin, you must call
+# System.loadLibrary() and pass the name of the library defined here;
+# for GameActivity/NativeActivity derived applications, the same library name must be
+# used in the AndroidManifest.xml file.
+add_library(${CMAKE_PROJECT_NAME} SHARED
+        # List C/C++ source files with relative paths to this CMakeLists.txt.
+        llama-android.cpp)
+
+# Specifies libraries CMake should link to your target library. You
+# can link libraries from various origins, such as libraries defined in this
+# build script, prebuilt third-party libraries, or Android system libraries.
+target_link_libraries(${CMAKE_PROJECT_NAME}
+        # List the libraries to link to the target library
+        llama
+        common
+        android
+        log)
diff --git a/examples/llama.android/app/src/main/cpp/llama-android.cpp b/examples/llama.android/app/src/main/cpp/llama-android.cpp
new file mode 100644
index 00000000..4af9de30
--- /dev/null
+++ b/examples/llama.android/app/src/main/cpp/llama-android.cpp
@@ -0,0 +1,443 @@
+#include <android/log.h>
+#include <jni.h>
+#include <iomanip>
+#include <math.h>
+#include <string>
+#include <unistd.h>
+#include "llama.h"
+#include "common/common.h"
+
+// Write C++ code here.
+//
+// Do not forget to dynamically load the C++ library into your application.
+//
+// For instance,
+//
+// In MainActivity.java:
+//    static {
+//       System.loadLibrary("llama-android");
+//    }
+//
+// Or, in MainActivity.kt:
+//    companion object {
+//      init {
+//         System.loadLibrary("llama-android")
+//      }
+//    }
+
+#define TAG "llama-android.cpp"
+#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
+#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
+
+jclass la_int_var;
+jmethodID la_int_var_value;
+jmethodID la_int_var_inc;
+
+std::string cached_token_chars;
+
+bool is_valid_utf8(const char * string) {
+    if (!string) {
+        return true;
+    }
+
+    const unsigned char * bytes = (const unsigned char *)string;
+    int num;
+
+    while (*bytes != 0x00) {
+        if ((*bytes & 0x80) == 0x00) {
+            // U+0000 to U+007F
+            num = 1;
+        } else if ((*bytes & 0xE0) == 0xC0) {
+            // U+0080 to U+07FF
+            num = 2;
+        } else if ((*bytes & 0xF0) == 0xE0) {
+            // U+0800 to U+FFFF
+            num = 3;
+        } else if ((*bytes & 0xF8) == 0xF0) {
+            // U+10000 to U+10FFFF
+            num = 4;
+        } else {
+            return false;
+        }
+
+        bytes += 1;
+        for (int i = 1; i < num; ++i) {
+            if ((*bytes & 0xC0) != 0x80) {
+                return false;
+            }
+            bytes += 1;
+        }
+    }
+
+    return true;
+}
+
+static void log_callback(ggml_log_level level, const char * fmt, void * data) {
+    if (level == GGML_LOG_LEVEL_ERROR)     __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
+    else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
+    else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
+    else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
+}
+
+extern "C"
+JNIEXPORT jlong JNICALL
+Java_com_example_llama_Llm_load_1model(JNIEnv *env, jobject, jstring filename) {
+    llama_model_params model_params = llama_model_default_params();
+
+    auto path_to_model = env->GetStringUTFChars(filename, 0);
+    LOGi("Loading model from %s", path_to_model);
+
+    auto model = llama_load_model_from_file(path_to_model, model_params);
+    env->ReleaseStringUTFChars(filename, path_to_model);
+
+    if (!model) {
+        LOGe("load_model() failed");
+        env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
+        return 0;
+    }
+
+    return reinterpret_cast<jlong>(model);
+}
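The exported symbol names in this file follow the standard JNI convention: "Java_", then the fully qualified class name with dots replaced by underscores, then the method name, with any underscore in the Java-side name escaped as "_1". So Java_com_example_llama_Llm_load_1model is the native side of load_model on com.example.llama.Llm, declared in Kotlin roughly like this (a reduced sketch; the real class later in this patch loads the library on a dedicated thread instead):

    package com.example.llama

    class Llm {
        companion object {
            init {
                // The library must be loaded before the external
                // function is first called.
                System.loadLibrary("llama-android")
            }
        }

        // Binds to Java_com_example_llama_Llm_load_1model at runtime.
        private external fun load_model(filename: String): Long
    }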
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_example_llama_Llm_free_1model(JNIEnv *, jobject, jlong model) {
+    llama_free_model(reinterpret_cast<llama_model *>(model));
+}
+
+extern "C"
+JNIEXPORT jlong JNICALL
+Java_com_example_llama_Llm_new_1context(JNIEnv *env, jobject, jlong jmodel) {
+    auto model = reinterpret_cast<llama_model *>(jmodel);
+
+    if (!model) {
+        LOGe("new_context(): model cannot be null");
+        env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
+        return 0;
+    }
+
+    int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
+    LOGi("Using %d threads", n_threads);
+
+    llama_context_params ctx_params = llama_context_default_params();
+    ctx_params.seed  = 1234;
+    ctx_params.n_ctx = 2048;
+    ctx_params.n_threads       = n_threads;
+    ctx_params.n_threads_batch = n_threads;
+
+    llama_context * context = llama_new_context_with_model(model, ctx_params);
+
+    if (!context) {
+        LOGe("llama_new_context_with_model() returned null)");
+        env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
+            "llama_new_context_with_model() returned null)");
+        return 0;
+    }
+
+    return reinterpret_cast<jlong>(context);
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_example_llama_Llm_free_1context(JNIEnv *, jobject, jlong context) {
+    llama_free(reinterpret_cast<llama_context *>(context));
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_example_llama_Llm_backend_1free(JNIEnv *, jobject) {
+    llama_backend_free();
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_example_llama_Llm_log_1to_1android(JNIEnv *, jobject) {
+    llama_log_set(log_callback, NULL);
+}
+
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_com_example_llama_Llm_bench_1model(
+        JNIEnv *env,
+        jobject,
+        jlong context_pointer,
+        jlong model_pointer,
+        jlong batch_pointer,
+        jint pp,
+        jint tg,
+        jint pl,
+        jint nr
+    ) {
+    auto pp_avg = 0.0;
+    auto tg_avg = 0.0;
+    auto pp_std = 0.0;
+    auto tg_std = 0.0;
+
+    const auto context = reinterpret_cast<llama_context *>(context_pointer);
+    const auto model = reinterpret_cast<llama_model *>(model_pointer);
+    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+
+    const int n_ctx = llama_n_ctx(context);
+
+    LOGi("n_ctx = %d", n_ctx);
+
+    int i, j;
+    int nri;
+    for (nri = 0; nri < nr; nri++) {
+        LOGi("Benchmark prompt processing (pp)");
+
+        llama_batch_clear(*batch);
+
+        const int n_tokens = pp;
+        for (i = 0; i < n_tokens; i++) {
+            llama_batch_add(*batch, 0, i, { 0 }, false);
+        }
+
+        batch->logits[batch->n_tokens - 1] = true;
+        llama_kv_cache_clear(context);
+
+        const auto t_pp_start = ggml_time_us();
+        if (llama_decode(context, *batch) != 0) {
+            LOGi("llama_decode() failed during prompt processing");
+        }
+        const auto t_pp_end = ggml_time_us();
+
+        // bench text generation
+
+        LOGi("Benchmark text generation (tg)");
+
+        llama_kv_cache_clear(context);
+        const auto t_tg_start = ggml_time_us();
+        for (i = 0; i < tg; i++) {
+
+            llama_batch_clear(*batch);
+            for (j = 0; j < pl; j++) {
+                llama_batch_add(*batch, 0, i, { j }, true);
+            }
+
+            LOGi("llama_decode() text generation: %d", i);
+            if (llama_decode(context, *batch) != 0) {
+                LOGi("llama_decode() failed during text generation");
+            }
+        }
+
+        const auto t_tg_end = ggml_time_us();
+
+        llama_kv_cache_clear(context);
+
+        const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
+        const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
+
+        const auto speed_pp = double(pp) / t_pp;
+        const auto speed_tg = double(pl * tg) / t_tg;
+
+        pp_avg += speed_pp;
+        tg_avg += speed_tg;
+
+        pp_std += speed_pp * speed_pp;
+        tg_std += speed_tg * speed_tg;
+
+        LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
+    }
+
+    pp_avg /= double(nr);
+    tg_avg /= double(nr);
+
+    if (nr > 1) {
+        pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
+        tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
+    } else {
+        pp_std = 0;
+        tg_std = 0;
+    }
+
+    char model_desc[128];
+    llama_model_desc(model, model_desc, sizeof(model_desc));
+
+    const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
+    const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
+
+    const auto backend = "(Android)"; // TODO: What should this be?
+
+    std::stringstream result;
+    result << std::setprecision(2);
+    result << "| model | size | params | backend | test | t/s |\n";
+    result << "| --- | --- | --- | --- | --- | --- |\n";
+    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
+    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
+
+    return env->NewStringUTF(result.str().c_str());
+}
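A note on the statistics in bench_model above: inside the loop, pp_std and tg_std accumulate the sum of squared speeds Σv², while pp_avg and tg_avg accumulate the plain sums that are then divided by n = nr. The final expressions are the Bessel-corrected sample standard deviation, s = sqrt((Σv² − n·x̄²) / (n − 1)), which is why the nr > 1 guard exists: with a single repetition the n − 1 denominator would be zero.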
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_example_llama_Llm_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
+    llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
+}
+
+extern "C"
+JNIEXPORT jlong JNICALL
+Java_com_example_llama_Llm_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
+
+    // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
+
+    llama_batch *batch = new llama_batch {
+        0,
+        nullptr,
+        nullptr,
+        nullptr,
+        nullptr,
+        nullptr,
+        nullptr,
+        0,
+        0,
+        0,
+    };
+
+    if (embd) {
+        batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
+    } else {
+        batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
+    }
+
+    batch->pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens);
+    batch->n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens);
+    batch->seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
+    for (int i = 0; i < n_tokens; ++i) {
+        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
+    }
+    batch->logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens);
+
+    return reinterpret_cast<jlong>(batch);
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_example_llama_Llm_backend_1init(JNIEnv *, jobject) {
+    llama_backend_init();
+}
+
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_com_example_llama_Llm_system_1info(JNIEnv *env, jobject) {
+    return env->NewStringUTF(llama_print_system_info());
+}
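Kotlin never sees the fields of llama_batch: new_batch returns the heap pointer as an opaque Long handle, and every later call passes that handle back across JNI. A sketch of the handle's lifecycle from the Kotlin side, mirroring the load() function further down (note that free_batch, as written, frees the buffers via llama_batch_free but appears to leave the new-allocated struct itself unreclaimed):

    val batch: Long = new_batch(512, 0, 1)  // 512 token slots, token mode (embd = 0), 1 sequence
    if (batch == 0L) throw IllegalStateException("new_batch() failed")
    // ... completion_init()/completion_loop() receive the same handle ...
    free_batch(batch)  // releases the native buffers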
+
+extern "C"
+JNIEXPORT jint JNICALL
+Java_com_example_llama_Llm_completion_1init(
+        JNIEnv *env,
+        jobject,
+        jlong context_pointer,
+        jlong batch_pointer,
+        jstring jtext,
+        jint n_len
+    ) {
+
+    cached_token_chars.clear();
+
+    const auto text = env->GetStringUTFChars(jtext, 0);
+    const auto context = reinterpret_cast<llama_context *>(context_pointer);
+    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+
+    const auto tokens_list = llama_tokenize(context, text, 1);
+
+    auto n_ctx = llama_n_ctx(context);
+    auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
+
+    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
+
+    if (n_kv_req > n_ctx) {
+        LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
+    }
+
+    for (auto id : tokens_list) {
+        LOGi("%s", llama_token_to_piece(context, id).c_str());
+    }
+
+    llama_batch_clear(*batch);
+
+    // evaluate the initial prompt
+    for (auto i = 0; i < tokens_list.size(); i++) {
+        llama_batch_add(*batch, tokens_list[i], i, { 0 }, false);
+    }
+
+    // llama_decode will output logits only for the last token of the prompt
+    batch->logits[batch->n_tokens - 1] = true;
+
+    if (llama_decode(context, *batch) != 0) {
+        LOGe("llama_decode() failed");
+    }
+
+    env->ReleaseStringUTFChars(jtext, text);
+
+    return batch->n_tokens;
+}
+
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_com_example_llama_Llm_completion_1loop(
+        JNIEnv * env,
+        jobject,
+        jlong context_pointer,
+        jlong batch_pointer,
+        jint n_len,
+        jobject intvar_ncur
+) {
+    const auto context = reinterpret_cast<llama_context *>(context_pointer);
+    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+    const auto model = llama_get_model(context);
+
+    if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
+    if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
+    if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
+
+    auto n_vocab = llama_n_vocab(model);
+    auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
+
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+    }
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+    // sample the most likely token
+    const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
+
+    const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
+    if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
+        return env->NewStringUTF("");
+    }
+
+    auto new_token_chars = llama_token_to_piece(context, new_token_id);
+    cached_token_chars += new_token_chars;
+
+    jstring new_token = nullptr;
+    if (is_valid_utf8(cached_token_chars.c_str())) {
+        new_token = env->NewStringUTF(cached_token_chars.c_str());
+        LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
+        cached_token_chars.clear();
+    } else {
+        new_token = env->NewStringUTF("");
+    }
+
+    llama_batch_clear(*batch);
+    llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
+
+    env->CallVoidMethod(intvar_ncur, la_int_var_inc);
+
+    if (llama_decode(context, *batch) != 0) {
+        LOGe("llama_decode() returned null");
+    }
+
+    return new_token;
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_example_llama_Llm_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
+    llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
+}
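completion_loop above accumulates raw token bytes in cached_token_chars and only emits once the buffer is valid UTF-8, since a multi-byte character can be split across tokens; until then it returns "" rather than null so the stream keeps going. A collector that does not want those placeholder emissions could filter them -- a sketch, assuming an llm instance and a prompt string:

    import kotlinx.coroutines.flow.filter

    llm.send(prompt)
        .filter { it.isNotEmpty() }  // drop partial-character placeholders
        .collect { piece -> print(piece) }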
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt b/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt
new file mode 100644
index 00000000..d86afee3
--- /dev/null
+++ b/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt
@@ -0,0 +1,172 @@
+package com.example.llama
+
+import android.util.Log
+import kotlinx.coroutines.CoroutineDispatcher
+import kotlinx.coroutines.asCoroutineDispatcher
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.withContext
+import java.util.concurrent.Executors
+import kotlin.concurrent.thread
+
+class Llm {
+    private val tag: String? = this::class.simpleName
+
+    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
+
+    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
+        thread(start = false, name = "Llm-RunLoop") {
+            Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
+
+            // No-op if called more than once.
+            System.loadLibrary("llama-android")
+
+            // Set llama log handler to Android
+            log_to_android()
+            backend_init(false)
+
+            Log.d(tag, system_info())
+
+            it.run()
+        }.apply {
+            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
+                Log.e(tag, "Unhandled exception", exception)
+            }
+        }
+    }.asCoroutineDispatcher()
+
+    private val nlen: Int = 64
+
+    private external fun log_to_android()
+    private external fun load_model(filename: String): Long
+    private external fun free_model(model: Long)
+    private external fun new_context(model: Long): Long
+    private external fun free_context(context: Long)
+    private external fun backend_init(numa: Boolean)
+    private external fun backend_free()
+    private external fun free_batch(batch: Long)
+    private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
+    private external fun bench_model(
+        context: Long,
+        model: Long,
+        batch: Long,
+        pp: Int,
+        tg: Int,
+        pl: Int,
+        nr: Int
+    ): String
+
+    private external fun system_info(): String
+
+    private external fun completion_init(
+        context: Long,
+        batch: Long,
+        text: String,
+        nLen: Int
+    ): Int
+
+    private external fun completion_loop(
+        context: Long,
+        batch: Long,
+        nLen: Int,
+        ncur: IntVar
+    ): String?
+
+    private external fun kv_cache_clear(context: Long)
+
+    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
+        return withContext(runLoop) {
+            when (val state = threadLocalState.get()) {
+                is State.Loaded -> {
+                    Log.d(tag, "bench(): $state")
+                    bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
+                }
+
+                else -> throw IllegalStateException("No model loaded")
+            }
+        }
+    }
+
+    suspend fun load(pathToModel: String) {
+        withContext(runLoop) {
+            when (threadLocalState.get()) {
+                is State.Idle -> {
+                    val model = load_model(pathToModel)
+                    if (model == 0L) throw IllegalStateException("load_model() failed")
+
+                    val context = new_context(model)
+                    if (context == 0L) throw IllegalStateException("new_context() failed")
+
+                    val batch = new_batch(512, 0, 1)
+                    if (batch == 0L) throw IllegalStateException("new_batch() failed")
+
+                    Log.i(tag, "Loaded model $pathToModel")
+                    threadLocalState.set(State.Loaded(model, context, batch))
+                }
+                else -> throw IllegalStateException("Model already loaded")
+            }
+        }
+    }
+
+    fun send(message: String): Flow<String> = flow {
+        when (val state = threadLocalState.get()) {
+            is State.Loaded -> {
+                val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
+                while (ncur.value <= nlen) {
+                    val str = completion_loop(state.context, state.batch, nlen, ncur)
+                    if (str == null) {
+                        break
+                    }
+                    emit(str)
+                }
+                kv_cache_clear(state.context)
+            }
+            else -> {}
+        }
+    }.flowOn(runLoop)
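With bench(), load(), and send() above (and unload() just below), a minimal consumer of this class might look like the following sketch -- the coroutine scope and the model path are illustrative, not part of this commit:

    val llm = Llm.instance()
    scope.launch {
        llm.load("/data/local/tmp/model.gguf")  // throws IllegalStateException on failure
        llm.send("Tell me a joke.")
            .collect { piece -> print(piece) }  // pieces stream in as they decode
        llm.unload()
    }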
+
+    /**
+     * Unloads the model and frees resources.
+     *
+     * This is a no-op if there's no model loaded.
+     */
+    suspend fun unload() {
+        withContext(runLoop) {
+            when (val state = threadLocalState.get()) {
+                is State.Loaded -> {
+                    free_context(state.context)
+                    free_model(state.model)
+                    free_batch(state.batch)
+
+                    threadLocalState.set(State.Idle)
+                }
+                else -> {}
+            }
+        }
+    }
+
+    companion object {
+        private class IntVar(value: Int) {
+            @Volatile
+            var value: Int = value
+                private set
+
+            fun inc() {
+                synchronized(this) {
+                    value += 1
+                }
+            }
+        }
+
+        private sealed interface State {
+            data object Idle: State
+            data class Loaded(val model: Long, val context: Long, val batch: Long): State
+        }
+
+        // Enforce only one instance of Llm.
+        private val _instance: Llm = Llm()
+
+        fun instance(): Llm = _instance
+    }
+}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
index 45ac2993..be95e222 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
@@ -1,6 +1,5 @@
 package com.example.llama
 
-import android.llama.cpp.LLamaAndroid
 import android.util.Log
 import androidx.compose.runtime.getValue
 import androidx.compose.runtime.mutableStateOf
@@ -10,7 +9,7 @@ import androidx.lifecycle.viewModelScope
 import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.launch
 
-class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()): ViewModel() {
+class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
     companion object {
         @JvmStatic
         private val NanosPerSecond = 1_000_000_000.0
@@ -29,7 +28,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
 
         viewModelScope.launch {
             try {
-                llamaAndroid.unload()
+                llm.unload()
             } catch (exc: IllegalStateException) {
                 messages += exc.message!!
             }
@@ -45,7 +44,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
         messages += ""
 
         viewModelScope.launch {
-            llamaAndroid.send(text)
+            llm.send(text)
                 .catch {
                     Log.e(tag, "send() failed", it)
                     messages += it.message!!
@@ -58,7 +57,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
         viewModelScope.launch {
             try {
                 val start = System.nanoTime()
-                val warmupResult = llamaAndroid.bench(pp, tg, pl, nr)
+                val warmupResult = llm.bench(pp, tg, pl, nr)
                 val end = System.nanoTime()
 
                 messages += warmupResult
@@ -71,7 +70,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
                     return@launch
                 }
 
-                messages += llamaAndroid.bench(512, 128, 1, 3)
+                messages += llm.bench(512, 128, 1, 3)
             } catch (exc: IllegalStateException) {
                 Log.e(tag, "bench() failed", exc)
                 messages += exc.message!!
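The four bench arguments map onto the native benchmark loop in llama-android.cpp. A sketch of the second bench call above with the parameters named:

    llm.bench(
        pp = 512,  // prompt processing: tokens decoded in one batch
        tg = 128,  // text generation: number of decode iterations
        pl = 1,    // tokens decoded in parallel per iteration
        nr = 3     // repetitions, averaged with a standard deviation
    )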
@@ -82,7 +81,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
     fun load(pathToModel: String) {
         viewModelScope.launch {
             try {
-                llamaAndroid.load(pathToModel)
+                llm.load(pathToModel)
                 messages += "Loaded $pathToModel"
             } catch (exc: IllegalStateException) {
                 Log.e(tag, "load() failed", exc)
diff --git a/examples/llama.android/build.gradle.kts b/examples/llama.android/build.gradle.kts
index acd1ada7..50ebc821 100644
--- a/examples/llama.android/build.gradle.kts
+++ b/examples/llama.android/build.gradle.kts
@@ -2,5 +2,4 @@
 plugins {
     id("com.android.application") version "8.2.0" apply false
     id("org.jetbrains.kotlin.android") version "1.9.0" apply false
-    id("com.android.library") version "8.2.0" apply false
 }
diff --git a/examples/llama.android/llama/.gitignore b/examples/llama.android/llama/.gitignore
deleted file mode 100644
index 796b96d1..00000000
--- a/examples/llama.android/llama/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-/build
diff --git a/examples/llama.android/llama/CMakeLists.txt b/examples/llama.android/llama/CMakeLists.txt
deleted file mode 100644
index bb5738ae..00000000
--- a/examples/llama.android/llama/CMakeLists.txt
+++ /dev/null
@@ -1,50 +0,0 @@
-
-# For more information about using CMake with Android Studio, read the
-# documentation: https://d.android.com/studio/projects/add-native-code.html.
-# For more examples on how to use CMake, see https://github.com/android/ndk-samples.
-
-# Sets the minimum CMake version required for this project.
-cmake_minimum_required(VERSION 3.22.1)
-
-# Declares the project name. The project name can be accessed via ${PROJECT_NAME}.
-# Since this is the top-level CMakeLists.txt, the project name is also accessible
-# with ${CMAKE_PROJECT_NAME} (both CMake variables are in sync within the top-level
-# build script scope).
-project("llama-android")
-
-include(FetchContent)
-FetchContent_Declare(
-        llama
-        GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
-        GIT_TAG master
-)
-
-# Also provides "common"
-FetchContent_MakeAvailable(llama)
-
-# Creates and names a library, sets it as either STATIC
-# or SHARED, and provides the relative paths to its source code.
-# You can define multiple libraries, and CMake builds them for you.
-# Gradle automatically packages shared libraries with your APK.
-#
-# In this top-level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define
-# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME}
-# is preferred for the same purpose.
-#
-# In order to load a library into your app from Java/Kotlin, you must call
-# System.loadLibrary() and pass the name of the library defined here;
-# for GameActivity/NativeActivity derived applications, the same library name must be
-# used in the AndroidManifest.xml file.
-add_library(${CMAKE_PROJECT_NAME} SHARED
-        # List C/C++ source files with relative paths to this CMakeLists.txt.
-        llama-android.cpp)
-
-# Specifies libraries CMake should link to your target library. You
-# can link libraries from various origins, such as libraries defined in this
-# build script, prebuilt third-party libraries, or Android system libraries.
-target_link_libraries(${CMAKE_PROJECT_NAME}
-        # List the libraries to link to the target library
-        llama
-        common
-        android
-        log)
diff --git a/examples/llama.android/llama/consumer-rules.pro b/examples/llama.android/llama/consumer-rules.pro
deleted file mode 100644
index e69de29b..00000000
diff --git a/examples/llama.android/llama/proguard-rules.pro b/examples/llama.android/llama/proguard-rules.pro
deleted file mode 100644
index f1b42451..00000000
--- a/examples/llama.android/llama/proguard-rules.pro
+++ /dev/null
@@ -1,21 +0,0 @@
-# Add project specific ProGuard rules here.
-# You can control the set of applied configuration files using the
-# proguardFiles setting in build.gradle.
-#
-# For more details, see
-#   http://developer.android.com/guide/developing/tools/proguard.html
-
-# If your project uses WebView with JS, uncomment the following
-# and specify the fully qualified class name to the JavaScript interface
-# class:
-#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
-#   public *;
-#}
-
-# Uncomment this to preserve the line number information for
-# debugging stack traces.
-#-keepattributes SourceFile,LineNumberTable
-
-# If you keep the line number information, uncomment this to
-# hide the original source file name.
-#-renamesourcefileattribute SourceFile
diff --git a/examples/llama.android/llama/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt b/examples/llama.android/llama/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt
deleted file mode 100644
index 05d6ab5d..00000000
--- a/examples/llama.android/llama/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt
+++ /dev/null
@@ -1,24 +0,0 @@
-package android.llama.cpp
-
-import androidx.test.platform.app.InstrumentationRegistry
-import androidx.test.ext.junit.runners.AndroidJUnit4
-
-import org.junit.Test
-import org.junit.runner.RunWith
-
-import org.junit.Assert.*
-
-/**
- * Instrumented test, which will execute on an Android device.
- *
- * See [testing documentation](http://d.android.com/tools/testing).
- */
-@RunWith(AndroidJUnit4::class)
-class ExampleInstrumentedTest {
-    @Test
-    fun useAppContext() {
-        // Context of the app under test.
-        val appContext = InstrumentationRegistry.getInstrumentation().targetContext
-        assertEquals("android.llama.cpp.test", appContext.packageName)
-    }
-}
diff --git a/examples/llama.android/llama/src/main/AndroidManifest.xml b/examples/llama.android/llama/src/main/AndroidManifest.xml
deleted file mode 100644
index 8bdb7e14..00000000
--- a/examples/llama.android/llama/src/main/AndroidManifest.xml
+++ /dev/null
@@ -1,4 +0,0 @@
-
-
-
-
diff --git a/examples/llama.android/llama/src/main/cpp/CMakeLists.txt b/examples/llama.android/llama/src/main/cpp/CMakeLists.txt
deleted file mode 100644
index 42ebaad4..00000000
--- a/examples/llama.android/llama/src/main/cpp/CMakeLists.txt
+++ /dev/null
@@ -1,49 +0,0 @@
-# For more information about using CMake with Android Studio, read the
-# documentation: https://d.android.com/studio/projects/add-native-code.html.
-# For more examples on how to use CMake, see https://github.com/android/ndk-samples.
-
-# Sets the minimum CMake version required for this project.
-cmake_minimum_required(VERSION 3.22.1)
-
-# Declares the project name. The project name can be accessed via ${PROJECT_NAME}.
-# Since this is the top-level CMakeLists.txt, the project name is also accessible
-# with ${CMAKE_PROJECT_NAME} (both CMake variables are in sync within the top-level
-# build script scope).
-project("llama-android") - -include(FetchContent) -FetchContent_Declare( - llama - GIT_REPOSITORY https://github.com/ggerganov/llama.cpp - GIT_TAG master -) - -# Also provides "common" -FetchContent_MakeAvailable(llama) - -# Creates and names a library, sets it as either STATIC -# or SHARED, and provides the relative paths to its source code. -# You can define multiple libraries, and CMake builds them for you. -# Gradle automatically packages shared libraries with your APK. -# -# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define -# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME} -# is preferred for the same purpose. -# -# In order to load a library into your app from Java/Kotlin, you must call -# System.loadLibrary() and pass the name of the library defined here; -# for GameActivity/NativeActivity derived applications, the same library name must be -# used in the AndroidManifest.xml file. -add_library(${CMAKE_PROJECT_NAME} SHARED - # List C/C++ source files with relative paths to this CMakeLists.txt. - llama-android.cpp) - -# Specifies libraries CMake should link to your target library. You -# can link libraries from various origins, such as libraries defined in this -# build script, prebuilt third-party libraries, or Android system libraries. -target_link_libraries(${CMAKE_PROJECT_NAME} - # List libraries link to the target library - llama - common - android - log) diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp deleted file mode 100644 index 874158ef..00000000 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ /dev/null @@ -1,443 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "llama.h" -#include "common/common.h" - -// Write C++ code here. -// -// Do not forget to dynamically load the C++ library into your application. -// -// For instance, -// -// In MainActivity.java: -// static { -// System.loadLibrary("llama-android"); -// } -// -// Or, in MainActivity.kt: -// companion object { -// init { -// System.loadLibrary("llama-android") -// } -// } - -#define TAG "llama-android.cpp" -#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) -#define LOGe(...) 
-
-jclass la_int_var;
-jmethodID la_int_var_value;
-jmethodID la_int_var_inc;
-
-std::string cached_token_chars;
-
-bool is_valid_utf8(const char * string) {
-    if (!string) {
-        return true;
-    }
-
-    const unsigned char * bytes = (const unsigned char *)string;
-    int num;
-
-    while (*bytes != 0x00) {
-        if ((*bytes & 0x80) == 0x00) {
-            // U+0000 to U+007F
-            num = 1;
-        } else if ((*bytes & 0xE0) == 0xC0) {
-            // U+0080 to U+07FF
-            num = 2;
-        } else if ((*bytes & 0xF0) == 0xE0) {
-            // U+0800 to U+FFFF
-            num = 3;
-        } else if ((*bytes & 0xF8) == 0xF0) {
-            // U+10000 to U+10FFFF
-            num = 4;
-        } else {
-            return false;
-        }
-
-        bytes += 1;
-        for (int i = 1; i < num; ++i) {
-            if ((*bytes & 0xC0) != 0x80) {
-                return false;
-            }
-            bytes += 1;
-        }
-    }
-
-    return true;
-}
-
-static void log_callback(ggml_log_level level, const char * fmt, void * data) {
-    if (level == GGML_LOG_LEVEL_ERROR)     __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
-    else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
-    else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
-    else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
-    llama_model_params model_params = llama_model_default_params();
-
-    auto path_to_model = env->GetStringUTFChars(filename, 0);
-    LOGi("Loading model from %s", path_to_model);
-
-    auto model = llama_load_model_from_file(path_to_model, model_params);
-    env->ReleaseStringUTFChars(filename, path_to_model);
-
-    if (!model) {
-        LOGe("load_model() failed");
-        env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
-        return 0;
-    }
-
-    return reinterpret_cast<jlong>(model);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
-    llama_free_model(reinterpret_cast<llama_model *>(model));
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) {
-    auto model = reinterpret_cast<llama_model *>(jmodel);
-
-    if (!model) {
-        LOGe("new_context(): model cannot be null");
-        env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
-        return 0;
-    }
-
-    int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
-    LOGi("Using %d threads", n_threads);
-
-    llama_context_params ctx_params = llama_context_default_params();
-    ctx_params.seed  = 1234;
-    ctx_params.n_ctx = 2048;
-    ctx_params.n_threads       = n_threads;
-    ctx_params.n_threads_batch = n_threads;
-
-    llama_context * context = llama_new_context_with_model(model, ctx_params);
-
-    if (!context) {
-        LOGe("llama_new_context_with_model() returned null)");
-        env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
-            "llama_new_context_with_model() returned null)");
-        return 0;
-    }
-
-    return reinterpret_cast<jlong>(context);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) {
-    llama_free(reinterpret_cast<llama_context *>(context));
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) {
-    llama_backend_free();
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
-    llama_log_set(log_callback, NULL);
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_bench_1model(
-        JNIEnv *env,
-        jobject,
-        jlong context_pointer,
-        jlong model_pointer,
-        jlong batch_pointer,
-        jint pp,
-        jint tg,
-        jint pl,
-        jint nr
-    ) {
-    auto pp_avg = 0.0;
-    auto tg_avg = 0.0;
-    auto pp_std = 0.0;
-    auto tg_std = 0.0;
-
-    const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto model = reinterpret_cast<llama_model *>(model_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-
-    const int n_ctx = llama_n_ctx(context);
-
-    LOGi("n_ctx = %d", n_ctx);
-
-    int i, j;
-    int nri;
-    for (nri = 0; nri < nr; nri++) {
-        LOGi("Benchmark prompt processing (pp)");
-
-        llama_batch_clear(*batch);
-
-        const int n_tokens = pp;
-        for (i = 0; i < n_tokens; i++) {
-            llama_batch_add(*batch, 0, i, { 0 }, false);
-        }
-
-        batch->logits[batch->n_tokens - 1] = true;
-        llama_kv_cache_clear(context);
-
-        const auto t_pp_start = ggml_time_us();
-        if (llama_decode(context, *batch) != 0) {
-            LOGi("llama_decode() failed during prompt processing");
-        }
-        const auto t_pp_end = ggml_time_us();
-
-        // bench text generation
-
-        LOGi("Benchmark text generation (tg)");
-
-        llama_kv_cache_clear(context);
-        const auto t_tg_start = ggml_time_us();
-        for (i = 0; i < tg; i++) {
-
-            llama_batch_clear(*batch);
-            for (j = 0; j < pl; j++) {
-                llama_batch_add(*batch, 0, i, { j }, true);
-            }
-
-            LOGi("llama_decode() text generation: %d", i);
-            if (llama_decode(context, *batch) != 0) {
-                LOGi("llama_decode() failed during text generation");
-            }
-        }
-
-        const auto t_tg_end = ggml_time_us();
-
-        llama_kv_cache_clear(context);
-
-        const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
-        const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
-
-        const auto speed_pp = double(pp) / t_pp;
-        const auto speed_tg = double(pl * tg) / t_tg;
-
-        pp_avg += speed_pp;
-        tg_avg += speed_tg;
-
-        pp_std += speed_pp * speed_pp;
-        tg_std += speed_tg * speed_tg;
-
-        LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
-    }
-
-    pp_avg /= double(nr);
-    tg_avg /= double(nr);
-
-    if (nr > 1) {
-        pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
-        tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
-    } else {
-        pp_std = 0;
-        tg_std = 0;
-    }
-
-    char model_desc[128];
-    llama_model_desc(model, model_desc, sizeof(model_desc));
-
-    const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
-    const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
-
-    const auto backend = "(Android)"; // TODO: What should this be?
-
-    std::stringstream result;
-    result << std::setprecision(2);
-    result << "| model | size | params | backend | test | t/s |\n";
-    result << "| --- | --- | --- | --- | --- | --- |\n";
-    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
-    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
-
-    return env->NewStringUTF(result.str().c_str());
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
-    llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
-
-    // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
-
-    llama_batch *batch = new llama_batch {
-        0,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        0,
-        0,
-        0,
-    };
-
-    if (embd) {
-        batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
-    } else {
-        batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
-    }
-
-    batch->pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens);
-    batch->n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens);
-    batch->seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
-    for (int i = 0; i < n_tokens; ++i) {
-        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
-    }
-    batch->logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens);
-
-    return reinterpret_cast<jlong>(batch);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
-    llama_backend_init();
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) {
-    return env->NewStringUTF(llama_print_system_info());
-}
-
-extern "C"
-JNIEXPORT jint JNICALL
-Java_android_llama_cpp_LLamaAndroid_completion_1init(
-        JNIEnv *env,
-        jobject,
-        jlong context_pointer,
-        jlong batch_pointer,
-        jstring jtext,
-        jint n_len
-    ) {
-
-    cached_token_chars.clear();
-
-    const auto text = env->GetStringUTFChars(jtext, 0);
-    const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-
-    const auto tokens_list = llama_tokenize(context, text, 1);
-
-    auto n_ctx = llama_n_ctx(context);
-    auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
-
-    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
-
-    if (n_kv_req > n_ctx) {
-        LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
-    }
-
-    for (auto id : tokens_list) {
-        LOGi("%s", llama_token_to_piece(context, id).c_str());
-    }
-
-    llama_batch_clear(*batch);
-
-    // evaluate the initial prompt
-    for (auto i = 0; i < tokens_list.size(); i++) {
-        llama_batch_add(*batch, tokens_list[i], i, { 0 }, false);
-    }
-
-    // llama_decode will output logits only for the last token of the prompt
-    batch->logits[batch->n_tokens - 1] = true;
-
-    if (llama_decode(context, *batch) != 0) {
-        LOGe("llama_decode() failed");
-    }
-
-    env->ReleaseStringUTFChars(jtext, text);
-
-    return batch->n_tokens;
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_completion_1loop(
-        JNIEnv * env,
-        jobject,
-        jlong context_pointer,
-        jlong batch_pointer,
-        jint n_len,
-        jobject intvar_ncur
-) {
-    const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-    const auto model = llama_get_model(context);
-
-    if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
-    if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
-    if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
-
-    auto n_vocab = llama_n_vocab(model);
-    auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
-
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-    }
-
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-    // sample the most likely token
-    const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
-
-    const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
-    if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
-        return env->NewStringUTF("");
-    }
-
-    auto new_token_chars = llama_token_to_piece(context, new_token_id);
-    cached_token_chars += new_token_chars;
-
-    jstring new_token = nullptr;
-    if (is_valid_utf8(cached_token_chars.c_str())) {
-        new_token = env->NewStringUTF(cached_token_chars.c_str());
-        LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
-        cached_token_chars.clear();
-    } else {
-        new_token = env->NewStringUTF("");
-    }
-
-    llama_batch_clear(*batch);
-    llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
-
-    env->CallVoidMethod(intvar_ncur, la_int_var_inc);
-
-    if (llama_decode(context, *batch) != 0) {
-        LOGe("llama_decode() returned null");
-    }
-
-    return new_token;
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
-    llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
-}
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
deleted file mode 100644
index 6c63e54e..00000000
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ /dev/null
@@ -1,172 +0,0 @@
-package android.llama.cpp
-
-import android.util.Log
-import kotlinx.coroutines.CoroutineDispatcher
-import kotlinx.coroutines.asCoroutineDispatcher
-import kotlinx.coroutines.flow.Flow
-import kotlinx.coroutines.flow.flow
-import kotlinx.coroutines.flow.flowOn
-import kotlinx.coroutines.withContext
-import java.util.concurrent.Executors
-import kotlin.concurrent.thread
-
-class LLamaAndroid {
-    private val tag: String? = this::class.simpleName
-
-    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
-
-    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
-        thread(start = false, name = "Llm-RunLoop") {
-            Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
-
-            // No-op if called more than once.
- System.loadLibrary("llama-android") - - // Set llama log handler to Android - log_to_android() - backend_init(false) - - Log.d(tag, system_info()) - - it.run() - }.apply { - uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable -> - Log.e(tag, "Unhandled exception", exception) - } - } - }.asCoroutineDispatcher() - - private val nlen: Int = 64 - - private external fun log_to_android() - private external fun load_model(filename: String): Long - private external fun free_model(model: Long) - private external fun new_context(model: Long): Long - private external fun free_context(context: Long) - private external fun backend_init(numa: Boolean) - private external fun backend_free() - private external fun free_batch(batch: Long) - private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long - private external fun bench_model( - context: Long, - model: Long, - batch: Long, - pp: Int, - tg: Int, - pl: Int, - nr: Int - ): String - - private external fun system_info(): String - - private external fun completion_init( - context: Long, - batch: Long, - text: String, - nLen: Int - ): Int - - private external fun completion_loop( - context: Long, - batch: Long, - nLen: Int, - ncur: IntVar - ): String? - - private external fun kv_cache_clear(context: Long) - - suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String { - return withContext(runLoop) { - when (val state = threadLocalState.get()) { - is State.Loaded -> { - Log.d(tag, "bench(): $state") - bench_model(state.context, state.model, state.batch, pp, tg, pl, nr) - } - - else -> throw IllegalStateException("No model loaded") - } - } - } - - suspend fun load(pathToModel: String) { - withContext(runLoop) { - when (threadLocalState.get()) { - is State.Idle -> { - val model = load_model(pathToModel) - if (model == 0L) throw IllegalStateException("load_model() failed") - - val context = new_context(model) - if (context == 0L) throw IllegalStateException("new_context() failed") - - val batch = new_batch(512, 0, 1) - if (batch == 0L) throw IllegalStateException("new_batch() failed") - - Log.i(tag, "Loaded model $pathToModel") - threadLocalState.set(State.Loaded(model, context, batch)) - } - else -> throw IllegalStateException("Model already loaded") - } - } - } - - fun send(message: String): Flow = flow { - when (val state = threadLocalState.get()) { - is State.Loaded -> { - val ncur = IntVar(completion_init(state.context, state.batch, message, nlen)) - while (ncur.value <= nlen) { - val str = completion_loop(state.context, state.batch, nlen, ncur) - if (str == null) { - break - } - emit(str) - } - kv_cache_clear(state.context) - } - else -> {} - } - }.flowOn(runLoop) - - /** - * Unloads the model and frees resources. - * - * This is a no-op if there's no model loaded. - */ - suspend fun unload() { - withContext(runLoop) { - when (val state = threadLocalState.get()) { - is State.Loaded -> { - free_context(state.context) - free_model(state.model) - free_batch(state.batch) - - threadLocalState.set(State.Idle) - } - else -> {} - } - } - } - - companion object { - private class IntVar(value: Int) { - @Volatile - var value: Int = value - private set - - fun inc() { - synchronized(this) { - value += 1 - } - } - } - - private sealed interface State { - data object Idle: State - data class Loaded(val model: Long, val context: Long, val batch: Long): State - } - - // Enforce only one instance of Llm. 
-        private val _instance: LLamaAndroid = LLamaAndroid()
-
-        fun instance(): LLamaAndroid = _instance
-    }
-}
diff --git a/examples/llama.android/llama/src/test/java/android/llama/cpp/ExampleUnitTest.kt b/examples/llama.android/llama/src/test/java/android/llama/cpp/ExampleUnitTest.kt
deleted file mode 100644
index cbbb974d..00000000
--- a/examples/llama.android/llama/src/test/java/android/llama/cpp/ExampleUnitTest.kt
+++ /dev/null
@@ -1,17 +0,0 @@
-package android.llama.cpp
-
-import org.junit.Test
-
-import org.junit.Assert.*
-
-/**
- * Example local unit test, which will execute on the development machine (host).
- *
- * See [testing documentation](http://d.android.com/tools/testing).
- */
-class ExampleUnitTest {
-    @Test
-    fun addition_isCorrect() {
-        assertEquals(4, 2 + 2)
-    }
-}
diff --git a/examples/llama.android/settings.gradle.kts b/examples/llama.android/settings.gradle.kts
index c7c1a034..2ba32c4f 100644
--- a/examples/llama.android/settings.gradle.kts
+++ b/examples/llama.android/settings.gradle.kts
@@ -15,4 +15,3 @@ dependencyResolutionManagement {
 
 rootProject.name = "LlamaAndroid"
 include(":app")
-include(":llama")
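Taken together, the reverted design funnels every native call through one dedicated thread: the dispatcher built around Executors.newSingleThreadExecutor loads the library and initializes the backend exactly once, and because the class is a singleton whose State lives in a ThreadLocal, that state is only ever read or written from the run-loop thread, so the JNI layer needs no locking. The pattern in isolation, as a sketch (the names are placeholders, not part of the commit):

    import java.util.concurrent.Executors
    import kotlinx.coroutines.asCoroutineDispatcher
    import kotlinx.coroutines.withContext

    // One single-threaded dispatcher owns all native calls.
    val runLoop = Executors.newSingleThreadExecutor { task ->
        Thread(task, "Llm-RunLoop")  // per-thread setup (System.loadLibrary, ...) goes here
    }.asCoroutineDispatcher()

    // Every native entry point is wrapped like this.
    suspend fun <T> onRunLoop(block: () -> T): T = withContext(runLoop) { block() }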