]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
android : module (#7502)
authorElton Kola <redacted>
Sat, 25 May 2024 08:11:33 +0000 (04:11 -0400)
committerGitHub <redacted>
Sat, 25 May 2024 08:11:33 +0000 (11:11 +0300)
* move ndk code to a new library

* add gradle file

18 files changed:
examples/llama.android/app/build.gradle.kts
examples/llama.android/app/src/main/cpp/CMakeLists.txt [deleted file]
examples/llama.android/app/src/main/cpp/llama-android.cpp [deleted file]
examples/llama.android/app/src/main/java/com/example/llama/Llm.kt [deleted file]
examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
examples/llama.android/build.gradle.kts
examples/llama.android/llama/.gitignore [new file with mode: 0644]
examples/llama.android/llama/CMakeLists.txt [new file with mode: 0644]
examples/llama.android/llama/build.gradle.kts [new file with mode: 0644]
examples/llama.android/llama/consumer-rules.pro [new file with mode: 0644]
examples/llama.android/llama/proguard-rules.pro [new file with mode: 0644]
examples/llama.android/llama/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt [new file with mode: 0644]
examples/llama.android/llama/src/main/AndroidManifest.xml [new file with mode: 0644]
examples/llama.android/llama/src/main/cpp/CMakeLists.txt [new file with mode: 0644]
examples/llama.android/llama/src/main/cpp/llama-android.cpp [new file with mode: 0644]
examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt [new file with mode: 0644]
examples/llama.android/llama/src/test/java/android/llama/cpp/ExampleUnitTest.kt [new file with mode: 0644]
examples/llama.android/settings.gradle.kts

index d42140efe816813c6c7f8a0d5f25a45764aa903b..8d1b37195efd40a8336391d80b4d90064dba6ee0 100644 (file)
@@ -7,8 +7,6 @@ android {
     namespace = "com.example.llama"
     compileSdk = 34
 
-    ndkVersion = "26.1.10909125"
-
     defaultConfig {
         applicationId = "com.example.llama"
         minSdk = 33
@@ -20,17 +18,6 @@ android {
         vectorDrawables {
             useSupportLibrary = true
         }
-        ndk {
-            // Add NDK properties if wanted, e.g.
-            // abiFilters += listOf("arm64-v8a")
-        }
-        externalNativeBuild {
-            cmake {
-                arguments += "-DCMAKE_BUILD_TYPE=Release"
-                cppFlags += listOf()
-                arguments += listOf()
-            }
-        }
     }
 
     buildTypes {
@@ -55,17 +42,6 @@ android {
     composeOptions {
         kotlinCompilerExtensionVersion = "1.5.1"
     }
-    packaging {
-        resources {
-            excludes += "/META-INF/{AL2.0,LGPL2.1}"
-        }
-    }
-    externalNativeBuild {
-        cmake {
-            path = file("src/main/cpp/CMakeLists.txt")
-            version = "3.22.1"
-        }
-    }
 }
 
 dependencies {
@@ -78,6 +54,7 @@ dependencies {
     implementation("androidx.compose.ui:ui-graphics")
     implementation("androidx.compose.ui:ui-tooling-preview")
     implementation("androidx.compose.material3:material3")
+    implementation(project(":llama"))
     testImplementation("junit:junit:4.13.2")
     androidTestImplementation("androidx.test.ext:junit:1.1.5")
     androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
diff --git a/examples/llama.android/app/src/main/cpp/CMakeLists.txt b/examples/llama.android/app/src/main/cpp/CMakeLists.txt
deleted file mode 100644 (file)
index 4536974..0000000
+++ /dev/null
@@ -1,55 +0,0 @@
-
-# For more information about using CMake with Android Studio, read the
-# documentation: https://d.android.com/studio/projects/add-native-code.html.
-# For more examples on how to use CMake, see https://github.com/android/ndk-samples.
-
-# Sets the minimum CMake version required for this project.
-cmake_minimum_required(VERSION 3.22.1)
-
-# Declares the project name. The project name can be accessed via ${ PROJECT_NAME},
-# Since this is the top level CMakeLists.txt, the project name is also accessible
-# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level
-# build script scope).
-project("llama-android")
-
-## Fetch latest llama.cpp from GitHub
-#include(FetchContent)
-#FetchContent_Declare(
-#        llama
-#        GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
-#        GIT_TAG        master
-#)
-#
-## Also provides "common"
-#FetchContent_MakeAvailable(llama)
-
-# llama.cpp CI uses the code from the current branch
-# ref: https://github.com/ggerganov/llama.cpp/pull/7341#issuecomment-2117617700
-add_subdirectory(../../../../../../ build-llama)
-
-# Creates and names a library, sets it as either STATIC
-# or SHARED, and provides the relative paths to its source code.
-# You can define multiple libraries, and CMake builds them for you.
-# Gradle automatically packages shared libraries with your APK.
-#
-# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define
-# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME}
-# is preferred for the same purpose.
-#
-# In order to load a library into your app from Java/Kotlin, you must call
-# System.loadLibrary() and pass the name of the library defined here;
-# for GameActivity/NativeActivity derived applications, the same library name must be
-# used in the AndroidManifest.xml file.
-add_library(${CMAKE_PROJECT_NAME} SHARED
-    # List C/C++ source files with relative paths to this CMakeLists.txt.
-    llama-android.cpp)
-
-# Specifies libraries CMake should link to your target library. You
-# can link libraries from various origins, such as libraries defined in this
-# build script, prebuilt third-party libraries, or Android system libraries.
-target_link_libraries(${CMAKE_PROJECT_NAME}
-    # List libraries link to the target library
-    llama
-    common
-    android
-    log)
diff --git a/examples/llama.android/app/src/main/cpp/llama-android.cpp b/examples/llama.android/app/src/main/cpp/llama-android.cpp
deleted file mode 100644 (file)
index 4af9de3..0000000
+++ /dev/null
@@ -1,443 +0,0 @@
-#include <android/log.h>
-#include <jni.h>
-#include <iomanip>
-#include <math.h>
-#include <string>
-#include <unistd.h>
-#include "llama.h"
-#include "common/common.h"
-
-// Write C++ code here.
-//
-// Do not forget to dynamically load the C++ library into your application.
-//
-// For instance,
-//
-// In MainActivity.java:
-//    static {
-//       System.loadLibrary("llama-android");
-//    }
-//
-// Or, in MainActivity.kt:
-//    companion object {
-//      init {
-//         System.loadLibrary("llama-android")
-//      }
-//    }
-
-#define TAG "llama-android.cpp"
-#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
-#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
-
-jclass la_int_var;
-jmethodID la_int_var_value;
-jmethodID la_int_var_inc;
-
-std::string cached_token_chars;
-
-bool is_valid_utf8(const char * string) {
-    if (!string) {
-        return true;
-    }
-
-    const unsigned char * bytes = (const unsigned char *)string;
-    int num;
-
-    while (*bytes != 0x00) {
-        if ((*bytes & 0x80) == 0x00) {
-            // U+0000 to U+007F
-            num = 1;
-        } else if ((*bytes & 0xE0) == 0xC0) {
-            // U+0080 to U+07FF
-            num = 2;
-        } else if ((*bytes & 0xF0) == 0xE0) {
-            // U+0800 to U+FFFF
-            num = 3;
-        } else if ((*bytes & 0xF8) == 0xF0) {
-            // U+10000 to U+10FFFF
-            num = 4;
-        } else {
-            return false;
-        }
-
-        bytes += 1;
-        for (int i = 1; i < num; ++i) {
-            if ((*bytes & 0xC0) != 0x80) {
-                return false;
-            }
-            bytes += 1;
-        }
-    }
-
-    return true;
-}
-
-static void log_callback(ggml_log_level level, const char * fmt, void * data) {
-    if (level == GGML_LOG_LEVEL_ERROR)     __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
-    else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
-    else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
-    else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_com_example_llama_Llm_load_1model(JNIEnv *env, jobject, jstring filename) {
-    llama_model_params model_params = llama_model_default_params();
-
-    auto path_to_model = env->GetStringUTFChars(filename, 0);
-    LOGi("Loading model from %s", path_to_model);
-
-    auto model = llama_load_model_from_file(path_to_model, model_params);
-    env->ReleaseStringUTFChars(filename, path_to_model);
-
-    if (!model) {
-        LOGe("load_model() failed");
-        env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
-        return 0;
-    }
-
-    return reinterpret_cast<jlong>(model);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_example_llama_Llm_free_1model(JNIEnv *, jobject, jlong model) {
-    llama_free_model(reinterpret_cast<llama_model *>(model));
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_com_example_llama_Llm_new_1context(JNIEnv *env, jobject, jlong jmodel) {
-    auto model = reinterpret_cast<llama_model *>(jmodel);
-
-    if (!model) {
-        LOGe("new_context(): model cannot be null");
-        env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
-        return 0;
-    }
-
-    int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
-    LOGi("Using %d threads", n_threads);
-
-    llama_context_params ctx_params = llama_context_default_params();
-    ctx_params.seed  = 1234;
-    ctx_params.n_ctx = 2048;
-    ctx_params.n_threads       = n_threads;
-    ctx_params.n_threads_batch = n_threads;
-
-    llama_context * context = llama_new_context_with_model(model, ctx_params);
-
-    if (!context) {
-        LOGe("llama_new_context_with_model() returned null)");
-        env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
-                      "llama_new_context_with_model() returned null)");
-        return 0;
-    }
-
-    return reinterpret_cast<jlong>(context);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_example_llama_Llm_free_1context(JNIEnv *, jobject, jlong context) {
-    llama_free(reinterpret_cast<llama_context *>(context));
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_example_llama_Llm_backend_1free(JNIEnv *, jobject) {
-    llama_backend_free();
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_example_llama_Llm_log_1to_1android(JNIEnv *, jobject) {
-    llama_log_set(log_callback, NULL);
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_com_example_llama_Llm_bench_1model(
-        JNIEnv *env,
-        jobject,
-        jlong context_pointer,
-        jlong model_pointer,
-        jlong batch_pointer,
-        jint pp,
-        jint tg,
-        jint pl,
-        jint nr
-        ) {
-    auto pp_avg = 0.0;
-    auto tg_avg = 0.0;
-    auto pp_std = 0.0;
-    auto tg_std = 0.0;
-
-    const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto model = reinterpret_cast<llama_model *>(model_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-
-    const int n_ctx = llama_n_ctx(context);
-
-    LOGi("n_ctx = %d", n_ctx);
-
-    int i, j;
-    int nri;
-    for (nri = 0; nri < nr; nri++) {
-        LOGi("Benchmark prompt processing (pp)");
-
-        llama_batch_clear(*batch);
-
-        const int n_tokens = pp;
-        for (i = 0; i < n_tokens; i++) {
-            llama_batch_add(*batch, 0, i, { 0 }, false);
-        }
-
-        batch->logits[batch->n_tokens - 1] = true;
-        llama_kv_cache_clear(context);
-
-        const auto t_pp_start = ggml_time_us();
-        if (llama_decode(context, *batch) != 0) {
-            LOGi("llama_decode() failed during prompt processing");
-        }
-        const auto t_pp_end = ggml_time_us();
-
-        // bench text generation
-
-        LOGi("Benchmark text generation (tg)");
-
-        llama_kv_cache_clear(context);
-        const auto t_tg_start = ggml_time_us();
-        for (i = 0; i < tg; i++) {
-
-            llama_batch_clear(*batch);
-            for (j = 0; j < pl; j++) {
-                llama_batch_add(*batch, 0, i, { j }, true);
-            }
-
-            LOGi("llama_decode() text generation: %d", i);
-            if (llama_decode(context, *batch) != 0) {
-                LOGi("llama_decode() failed during text generation");
-            }
-        }
-
-        const auto t_tg_end = ggml_time_us();
-
-        llama_kv_cache_clear(context);
-
-        const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
-        const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
-
-        const auto speed_pp = double(pp) / t_pp;
-        const auto speed_tg = double(pl * tg) / t_tg;
-
-        pp_avg += speed_pp;
-        tg_avg += speed_tg;
-
-        pp_std += speed_pp * speed_pp;
-        tg_std += speed_tg * speed_tg;
-
-        LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
-    }
-
-    pp_avg /= double(nr);
-    tg_avg /= double(nr);
-
-    if (nr > 1) {
-        pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
-        tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
-    } else {
-        pp_std = 0;
-        tg_std = 0;
-    }
-
-    char model_desc[128];
-    llama_model_desc(model, model_desc, sizeof(model_desc));
-
-    const auto model_size     = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
-    const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
-
-    const auto backend    = "(Android)"; // TODO: What should this be?
-
-    std::stringstream result;
-    result << std::setprecision(2);
-    result << "| model | size | params | backend | test | t/s |\n";
-    result << "| --- | --- | --- | --- | --- | --- |\n";
-    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
-    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
-
-    return env->NewStringUTF(result.str().c_str());
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_example_llama_Llm_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
-    llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_com_example_llama_Llm_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
-
-    // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
-
-    llama_batch *batch = new llama_batch {
-        0,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        0,
-        0,
-        0,
-    };
-
-    if (embd) {
-        batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
-    } else {
-        batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
-    }
-
-    batch->pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens);
-    batch->n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens);
-    batch->seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
-    for (int i = 0; i < n_tokens; ++i) {
-        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
-    }
-    batch->logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens);
-
-    return reinterpret_cast<jlong>(batch);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_example_llama_Llm_backend_1init(JNIEnv *, jobject) {
-    llama_backend_init();
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_com_example_llama_Llm_system_1info(JNIEnv *env, jobject) {
-    return env->NewStringUTF(llama_print_system_info());
-}
-
-extern "C"
-JNIEXPORT jint JNICALL
-Java_com_example_llama_Llm_completion_1init(
-        JNIEnv *env,
-        jobject,
-        jlong context_pointer,
-        jlong batch_pointer,
-        jstring jtext,
-        jint n_len
-    ) {
-
-    cached_token_chars.clear();
-
-    const auto text = env->GetStringUTFChars(jtext, 0);
-    const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-
-    const auto tokens_list = llama_tokenize(context, text, 1);
-
-    auto n_ctx = llama_n_ctx(context);
-    auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
-
-    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
-
-    if (n_kv_req > n_ctx) {
-        LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
-    }
-
-    for (auto id : tokens_list) {
-        LOGi("%s", llama_token_to_piece(context, id).c_str());
-    }
-
-    llama_batch_clear(*batch);
-
-    // evaluate the initial prompt
-    for (auto i = 0; i < tokens_list.size(); i++) {
-        llama_batch_add(*batch, tokens_list[i], i, { 0 }, false);
-    }
-
-    // llama_decode will output logits only for the last token of the prompt
-    batch->logits[batch->n_tokens - 1] = true;
-
-    if (llama_decode(context, *batch) != 0) {
-        LOGe("llama_decode() failed");
-    }
-
-    env->ReleaseStringUTFChars(jtext, text);
-
-    return batch->n_tokens;
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_com_example_llama_Llm_completion_1loop(
-        JNIEnv * env,
-        jobject,
-        jlong context_pointer,
-        jlong batch_pointer,
-        jint n_len,
-        jobject intvar_ncur
-) {
-    const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-    const auto model = llama_get_model(context);
-
-    if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
-    if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
-    if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
-
-    auto n_vocab = llama_n_vocab(model);
-    auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
-
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-    }
-
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-    // sample the most likely token
-    const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
-
-    const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
-    if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
-        return env->NewStringUTF("");
-    }
-
-    auto new_token_chars = llama_token_to_piece(context, new_token_id);
-    cached_token_chars += new_token_chars;
-
-    jstring new_token = nullptr;
-    if (is_valid_utf8(cached_token_chars.c_str())) {
-        new_token = env->NewStringUTF(cached_token_chars.c_str());
-        LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
-        cached_token_chars.clear();
-    } else {
-        new_token = env->NewStringUTF("");
-    }
-
-    llama_batch_clear(*batch);
-    llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
-
-    env->CallVoidMethod(intvar_ncur, la_int_var_inc);
-
-    if (llama_decode(context, *batch) != 0) {
-        LOGe("llama_decode() returned null");
-    }
-
-    return new_token;
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_example_llama_Llm_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
-    llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
-}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt b/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt
deleted file mode 100644 (file)
index d86afee..0000000
+++ /dev/null
@@ -1,172 +0,0 @@
-package com.example.llama
-
-import android.util.Log
-import kotlinx.coroutines.CoroutineDispatcher
-import kotlinx.coroutines.asCoroutineDispatcher
-import kotlinx.coroutines.flow.Flow
-import kotlinx.coroutines.flow.flow
-import kotlinx.coroutines.flow.flowOn
-import kotlinx.coroutines.withContext
-import java.util.concurrent.Executors
-import kotlin.concurrent.thread
-
-class Llm {
-    private val tag: String? = this::class.simpleName
-
-    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
-
-    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
-        thread(start = false, name = "Llm-RunLoop") {
-            Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
-
-            // No-op if called more than once.
-            System.loadLibrary("llama-android")
-
-            // Set llama log handler to Android
-            log_to_android()
-            backend_init(false)
-
-            Log.d(tag, system_info())
-
-            it.run()
-        }.apply {
-            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
-                Log.e(tag, "Unhandled exception", exception)
-            }
-        }
-    }.asCoroutineDispatcher()
-
-    private val nlen: Int = 64
-
-    private external fun log_to_android()
-    private external fun load_model(filename: String): Long
-    private external fun free_model(model: Long)
-    private external fun new_context(model: Long): Long
-    private external fun free_context(context: Long)
-    private external fun backend_init(numa: Boolean)
-    private external fun backend_free()
-    private external fun free_batch(batch: Long)
-    private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
-    private external fun bench_model(
-        context: Long,
-        model: Long,
-        batch: Long,
-        pp: Int,
-        tg: Int,
-        pl: Int,
-        nr: Int
-    ): String
-
-    private external fun system_info(): String
-
-    private external fun completion_init(
-        context: Long,
-        batch: Long,
-        text: String,
-        nLen: Int
-    ): Int
-
-    private external fun completion_loop(
-        context: Long,
-        batch: Long,
-        nLen: Int,
-        ncur: IntVar
-    ): String?
-
-    private external fun kv_cache_clear(context: Long)
-
-    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
-        return withContext(runLoop) {
-            when (val state = threadLocalState.get()) {
-                is State.Loaded -> {
-                    Log.d(tag, "bench(): $state")
-                    bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
-                }
-
-                else -> throw IllegalStateException("No model loaded")
-            }
-        }
-    }
-
-    suspend fun load(pathToModel: String) {
-        withContext(runLoop) {
-            when (threadLocalState.get()) {
-                is State.Idle -> {
-                    val model = load_model(pathToModel)
-                    if (model == 0L)  throw IllegalStateException("load_model() failed")
-
-                    val context = new_context(model)
-                    if (context == 0L) throw IllegalStateException("new_context() failed")
-
-                    val batch = new_batch(512, 0, 1)
-                    if (batch == 0L) throw IllegalStateException("new_batch() failed")
-
-                    Log.i(tag, "Loaded model $pathToModel")
-                    threadLocalState.set(State.Loaded(model, context, batch))
-                }
-                else -> throw IllegalStateException("Model already loaded")
-            }
-        }
-    }
-
-    fun send(message: String): Flow<String> = flow {
-        when (val state = threadLocalState.get()) {
-            is State.Loaded -> {
-                val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
-                while (ncur.value <= nlen) {
-                    val str = completion_loop(state.context, state.batch, nlen, ncur)
-                    if (str == null) {
-                        break
-                    }
-                    emit(str)
-                }
-                kv_cache_clear(state.context)
-            }
-            else -> {}
-        }
-    }.flowOn(runLoop)
-
-    /**
-     * Unloads the model and frees resources.
-     *
-     * This is a no-op if there's no model loaded.
-     */
-    suspend fun unload() {
-        withContext(runLoop) {
-            when (val state = threadLocalState.get()) {
-                is State.Loaded -> {
-                    free_context(state.context)
-                    free_model(state.model)
-                    free_batch(state.batch)
-
-                    threadLocalState.set(State.Idle)
-                }
-                else -> {}
-            }
-        }
-    }
-
-    companion object {
-        private class IntVar(value: Int) {
-            @Volatile
-            var value: Int = value
-                private set
-
-            fun inc() {
-                synchronized(this) {
-                    value += 1
-                }
-            }
-        }
-
-        private sealed interface State {
-            data object Idle: State
-            data class Loaded(val model: Long, val context: Long, val batch: Long): State
-        }
-
-        // Enforce only one instance of Llm.
-        private val _instance: Llm = Llm()
-
-        fun instance(): Llm = _instance
-    }
-}
index be95e22218332484b7e902cba73e54e9f3d307b9..45ac29938f441e91c95620eef19be76e38008a3c 100644 (file)
@@ -1,5 +1,6 @@
 package com.example.llama
 
+import android.llama.cpp.LLamaAndroid
 import android.util.Log
 import androidx.compose.runtime.getValue
 import androidx.compose.runtime.mutableStateOf
@@ -9,7 +10,7 @@ import androidx.lifecycle.viewModelScope
 import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.launch
 
-class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
+class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()): ViewModel() {
     companion object {
         @JvmStatic
         private val NanosPerSecond = 1_000_000_000.0
@@ -28,7 +29,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
 
         viewModelScope.launch {
             try {
-                llm.unload()
+                llamaAndroid.unload()
             } catch (exc: IllegalStateException) {
                 messages += exc.message!!
             }
@@ -44,7 +45,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
         messages += ""
 
         viewModelScope.launch {
-            llm.send(text)
+            llamaAndroid.send(text)
                 .catch {
                     Log.e(tag, "send() failed", it)
                     messages += it.message!!
@@ -57,7 +58,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
         viewModelScope.launch {
             try {
                 val start = System.nanoTime()
-                val warmupResult = llm.bench(pp, tg, pl, nr)
+                val warmupResult = llamaAndroid.bench(pp, tg, pl, nr)
                 val end = System.nanoTime()
 
                 messages += warmupResult
@@ -70,7 +71,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
                     return@launch
                 }
 
-                messages += llm.bench(512, 128, 1, 3)
+                messages += llamaAndroid.bench(512, 128, 1, 3)
             } catch (exc: IllegalStateException) {
                 Log.e(tag, "bench() failed", exc)
                 messages += exc.message!!
@@ -81,7 +82,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
     fun load(pathToModel: String) {
         viewModelScope.launch {
             try {
-                llm.load(pathToModel)
+                llamaAndroid.load(pathToModel)
                 messages += "Loaded $pathToModel"
             } catch (exc: IllegalStateException) {
                 Log.e(tag, "load() failed", exc)
index 50ebc821122f6e196eac128c65e1913bb3c855fa..acd1ada7d9b1a4e636f2d0a5a93653d66f8cb954 100644 (file)
@@ -2,4 +2,5 @@
 plugins {
     id("com.android.application") version "8.2.0" apply false
     id("org.jetbrains.kotlin.android") version "1.9.0" apply false
+    id("com.android.library") version "8.2.0" apply false
 }
diff --git a/examples/llama.android/llama/.gitignore b/examples/llama.android/llama/.gitignore
new file mode 100644 (file)
index 0000000..796b96d
--- /dev/null
@@ -0,0 +1 @@
+/build
diff --git a/examples/llama.android/llama/CMakeLists.txt b/examples/llama.android/llama/CMakeLists.txt
new file mode 100644 (file)
index 0000000..a5618ca
--- /dev/null
@@ -0,0 +1,55 @@
+
+# For more information about using CMake with Android Studio, read the
+# documentation: https://d.android.com/studio/projects/add-native-code.html.
+# For more examples on how to use CMake, see https://github.com/android/ndk-samples.
+
+# Sets the minimum CMake version required for this project.
+cmake_minimum_required(VERSION 3.22.1)
+
+# Declares the project name. The project name can be accessed via ${ PROJECT_NAME},
+# Since this is the top level CMakeLists.txt, the project name is also accessible
+# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level
+# build script scope).
+project("llama-android")
+
+## Fetch latest llama.cpp from GitHub
+#include(FetchContent)
+#FetchContent_Declare(
+#        llama
+#        GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
+#        GIT_TAG        master
+#)
+#
+## Also provides "common"
+#FetchContent_MakeAvailable(llama)
+
+# llama.cpp CI uses the code from the current branch
+# ref: https://github.com/ggerganov/llama.cpp/pull/7341#issuecomment-2117617700
+add_subdirectory(../../../../../../ build-llama)
+
+# Creates and names a library, sets it as either STATIC
+# or SHARED, and provides the relative paths to its source code.
+# You can define multiple libraries, and CMake builds them for you.
+# Gradle automatically packages shared libraries with your APK.
+#
+# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define
+# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME}
+# is preferred for the same purpose.
+#
+# In order to load a library into your app from Java/Kotlin, you must call
+# System.loadLibrary() and pass the name of the library defined here;
+# for GameActivity/NativeActivity derived applications, the same library name must be
+# used in the AndroidManifest.xml file.
+add_library(${CMAKE_PROJECT_NAME} SHARED
+    # List C/C++ source files with relative paths to this CMakeLists.txt.
+        llama-android.cpp)
+
+# Specifies libraries CMake should link to your target library. You
+# can link libraries from various origins, such as libraries defined in this
+# build script, prebuilt third-party libraries, or Android system libraries.
+target_link_libraries(${CMAKE_PROJECT_NAME}
+    # List libraries link to the target library
+    llama
+    common
+    android
+    log)
diff --git a/examples/llama.android/llama/build.gradle.kts b/examples/llama.android/llama/build.gradle.kts
new file mode 100644 (file)
index 0000000..0a38061
--- /dev/null
@@ -0,0 +1,68 @@
+plugins {
+    id("com.android.library")
+    id("org.jetbrains.kotlin.android")
+}
+
+android {
+    namespace = "android.llama.cpp"
+    compileSdk = 34
+
+    defaultConfig {
+        minSdk = 33
+
+        testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
+        consumerProguardFiles("consumer-rules.pro")
+        ndk {
+            // Add NDK properties if wanted, e.g.
+            // abiFilters += listOf("arm64-v8a")
+        }
+        externalNativeBuild {
+            cmake {
+                arguments += "-DCMAKE_BUILD_TYPE=Release"
+                cppFlags += listOf()
+                arguments += listOf()
+
+                cppFlags("")
+            }
+        }
+    }
+
+    buildTypes {
+        release {
+            isMinifyEnabled = false
+            proguardFiles(
+                getDefaultProguardFile("proguard-android-optimize.txt"),
+                "proguard-rules.pro"
+            )
+        }
+    }
+    externalNativeBuild {
+        cmake {
+            path("src/main/cpp/CMakeLists.txt")
+            version = "3.22.1"
+        }
+    }
+    compileOptions {
+        sourceCompatibility = JavaVersion.VERSION_1_8
+        targetCompatibility = JavaVersion.VERSION_1_8
+    }
+    kotlinOptions {
+        jvmTarget = "1.8"
+    }
+
+    packaging {
+        resources {
+            excludes += "/META-INF/{AL2.0,LGPL2.1}"
+        }
+    }
+}
+
+dependencies {
+
+    implementation("androidx.core:core-ktx:1.12.0")
+    implementation("androidx.appcompat:appcompat:1.6.1")
+    implementation("com.google.android.material:material:1.11.0")
+    testImplementation("junit:junit:4.13.2")
+    androidTestImplementation("androidx.test.ext:junit:1.1.5")
+    androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
+}
diff --git a/examples/llama.android/llama/consumer-rules.pro b/examples/llama.android/llama/consumer-rules.pro
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/examples/llama.android/llama/proguard-rules.pro b/examples/llama.android/llama/proguard-rules.pro
new file mode 100644 (file)
index 0000000..f1b4245
--- /dev/null
@@ -0,0 +1,21 @@
+# Add project specific ProGuard rules here.
+# You can control the set of applied configuration files using the
+# proguardFiles setting in build.gradle.
+#
+# For more details, see
+#   http://developer.android.com/guide/developing/tools/proguard.html
+
+# If your project uses WebView with JS, uncomment the following
+# and specify the fully qualified class name to the JavaScript interface
+# class:
+#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
+#   public *;
+#}
+
+# Uncomment this to preserve the line number information for
+# debugging stack traces.
+#-keepattributes SourceFile,LineNumberTable
+
+# If you keep the line number information, uncomment this to
+# hide the original source file name.
+#-renamesourcefileattribute SourceFile
diff --git a/examples/llama.android/llama/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt b/examples/llama.android/llama/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt
new file mode 100644 (file)
index 0000000..05d6ab5
--- /dev/null
@@ -0,0 +1,24 @@
+package android.llama.cpp
+
+import androidx.test.platform.app.InstrumentationRegistry
+import androidx.test.ext.junit.runners.AndroidJUnit4
+
+import org.junit.Test
+import org.junit.runner.RunWith
+
+import org.junit.Assert.*
+
+/**
+ * Instrumented test, which will execute on an Android device.
+ *
+ * See [testing documentation](http://d.android.com/tools/testing).
+ */
+@RunWith(AndroidJUnit4::class)
+class ExampleInstrumentedTest {
+    @Test
+    fun useAppContext() {
+        // Context of the app under test.
+        val appContext = InstrumentationRegistry.getInstrumentation().targetContext
+        assertEquals("android.llama.cpp.test", appContext.packageName)
+    }
+}
diff --git a/examples/llama.android/llama/src/main/AndroidManifest.xml b/examples/llama.android/llama/src/main/AndroidManifest.xml
new file mode 100644 (file)
index 0000000..8bdb7e1
--- /dev/null
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android">
+
+</manifest>
diff --git a/examples/llama.android/llama/src/main/cpp/CMakeLists.txt b/examples/llama.android/llama/src/main/cpp/CMakeLists.txt
new file mode 100644 (file)
index 0000000..42ebaad
--- /dev/null
@@ -0,0 +1,49 @@
+# For more information about using CMake with Android Studio, read the
+# documentation: https://d.android.com/studio/projects/add-native-code.html.
+# For more examples on how to use CMake, see https://github.com/android/ndk-samples.
+
+# Sets the minimum CMake version required for this project.
+cmake_minimum_required(VERSION 3.22.1)
+
+# Declares the project name. The project name can be accessed via ${ PROJECT_NAME},
+# Since this is the top level CMakeLists.txt, the project name is also accessible
+# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level
+# build script scope).
+project("llama-android")
+
+include(FetchContent)
+FetchContent_Declare(
+        llama
+        GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
+        GIT_TAG        master
+)
+
+# Also provides "common"
+FetchContent_MakeAvailable(llama)
+
+# Creates and names a library, sets it as either STATIC
+# or SHARED, and provides the relative paths to its source code.
+# You can define multiple libraries, and CMake builds them for you.
+# Gradle automatically packages shared libraries with your APK.
+#
+# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define
+# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME}
+# is preferred for the same purpose.
+#
+# In order to load a library into your app from Java/Kotlin, you must call
+# System.loadLibrary() and pass the name of the library defined here;
+# for GameActivity/NativeActivity derived applications, the same library name must be
+# used in the AndroidManifest.xml file.
+add_library(${CMAKE_PROJECT_NAME} SHARED
+        # List C/C++ source files with relative paths to this CMakeLists.txt.
+        llama-android.cpp)
+
+# Specifies libraries CMake should link to your target library. You
+# can link libraries from various origins, such as libraries defined in this
+# build script, prebuilt third-party libraries, or Android system libraries.
+target_link_libraries(${CMAKE_PROJECT_NAME}
+        # List libraries link to the target library
+        llama
+        common
+        android
+        log)
diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
new file mode 100644 (file)
index 0000000..874158e
--- /dev/null
@@ -0,0 +1,443 @@
+#include <android/log.h>
+#include <jni.h>
+#include <iomanip>
+#include <math.h>
+#include <string>
+#include <unistd.h>
+#include "llama.h"
+#include "common/common.h"
+
+// Write C++ code here.
+//
+// Do not forget to dynamically load the C++ library into your application.
+//
+// For instance,
+//
+// In MainActivity.java:
+//    static {
+//       System.loadLibrary("llama-android");
+//    }
+//
+// Or, in MainActivity.kt:
+//    companion object {
+//      init {
+//         System.loadLibrary("llama-android")
+//      }
+//    }
+
+#define TAG "llama-android.cpp"
+#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
+#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
+
+jclass la_int_var;
+jmethodID la_int_var_value;
+jmethodID la_int_var_inc;
+
+std::string cached_token_chars;
+
+bool is_valid_utf8(const char * string) {
+    if (!string) {
+        return true;
+    }
+
+    const unsigned char * bytes = (const unsigned char *)string;
+    int num;
+
+    while (*bytes != 0x00) {
+        if ((*bytes & 0x80) == 0x00) {
+            // U+0000 to U+007F
+            num = 1;
+        } else if ((*bytes & 0xE0) == 0xC0) {
+            // U+0080 to U+07FF
+            num = 2;
+        } else if ((*bytes & 0xF0) == 0xE0) {
+            // U+0800 to U+FFFF
+            num = 3;
+        } else if ((*bytes & 0xF8) == 0xF0) {
+            // U+10000 to U+10FFFF
+            num = 4;
+        } else {
+            return false;
+        }
+
+        bytes += 1;
+        for (int i = 1; i < num; ++i) {
+            if ((*bytes & 0xC0) != 0x80) {
+                return false;
+            }
+            bytes += 1;
+        }
+    }
+
+    return true;
+}
+
+static void log_callback(ggml_log_level level, const char * fmt, void * data) {
+    if (level == GGML_LOG_LEVEL_ERROR)     __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
+    else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
+    else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
+    else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
+}
+
+extern "C"
+JNIEXPORT jlong JNICALL
+Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
+    llama_model_params model_params = llama_model_default_params();
+
+    auto path_to_model = env->GetStringUTFChars(filename, 0);
+    LOGi("Loading model from %s", path_to_model);
+
+    auto model = llama_load_model_from_file(path_to_model, model_params);
+    env->ReleaseStringUTFChars(filename, path_to_model);
+
+    if (!model) {
+        LOGe("load_model() failed");
+        env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
+        return 0;
+    }
+
+    return reinterpret_cast<jlong>(model);
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
+    llama_free_model(reinterpret_cast<llama_model *>(model));
+}
+
+extern "C"
+JNIEXPORT jlong JNICALL
+Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) {
+    auto model = reinterpret_cast<llama_model *>(jmodel);
+
+    if (!model) {
+        LOGe("new_context(): model cannot be null");
+        env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
+        return 0;
+    }
+
+    int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
+    LOGi("Using %d threads", n_threads);
+
+    llama_context_params ctx_params = llama_context_default_params();
+    ctx_params.seed  = 1234;
+    ctx_params.n_ctx = 2048;
+    ctx_params.n_threads       = n_threads;
+    ctx_params.n_threads_batch = n_threads;
+
+    llama_context * context = llama_new_context_with_model(model, ctx_params);
+
+    if (!context) {
+        LOGe("llama_new_context_with_model() returned null)");
+        env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
+                      "llama_new_context_with_model() returned null)");
+        return 0;
+    }
+
+    return reinterpret_cast<jlong>(context);
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) {
+    llama_free(reinterpret_cast<llama_context *>(context));
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) {
+    llama_backend_free();
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
+    llama_log_set(log_callback, NULL);
+}
+
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_android_llama_cpp_LLamaAndroid_bench_1model(
+        JNIEnv *env,
+        jobject,
+        jlong context_pointer,
+        jlong model_pointer,
+        jlong batch_pointer,
+        jint pp,
+        jint tg,
+        jint pl,
+        jint nr
+        ) {
+    auto pp_avg = 0.0;
+    auto tg_avg = 0.0;
+    auto pp_std = 0.0;
+    auto tg_std = 0.0;
+
+    const auto context = reinterpret_cast<llama_context *>(context_pointer);
+    const auto model = reinterpret_cast<llama_model *>(model_pointer);
+    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+
+    const int n_ctx = llama_n_ctx(context);
+
+    LOGi("n_ctx = %d", n_ctx);
+
+    int i, j;
+    int nri;
+    for (nri = 0; nri < nr; nri++) {
+        LOGi("Benchmark prompt processing (pp)");
+
+        llama_batch_clear(*batch);
+
+        const int n_tokens = pp;
+        for (i = 0; i < n_tokens; i++) {
+            llama_batch_add(*batch, 0, i, { 0 }, false);
+        }
+
+        batch->logits[batch->n_tokens - 1] = true;
+        llama_kv_cache_clear(context);
+
+        const auto t_pp_start = ggml_time_us();
+        if (llama_decode(context, *batch) != 0) {
+            LOGi("llama_decode() failed during prompt processing");
+        }
+        const auto t_pp_end = ggml_time_us();
+
+        // bench text generation
+
+        LOGi("Benchmark text generation (tg)");
+
+        llama_kv_cache_clear(context);
+        const auto t_tg_start = ggml_time_us();
+        for (i = 0; i < tg; i++) {
+
+            llama_batch_clear(*batch);
+            for (j = 0; j < pl; j++) {
+                llama_batch_add(*batch, 0, i, { j }, true);
+            }
+
+            LOGi("llama_decode() text generation: %d", i);
+            if (llama_decode(context, *batch) != 0) {
+                LOGi("llama_decode() failed during text generation");
+            }
+        }
+
+        const auto t_tg_end = ggml_time_us();
+
+        llama_kv_cache_clear(context);
+
+        const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
+        const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
+
+        const auto speed_pp = double(pp) / t_pp;
+        const auto speed_tg = double(pl * tg) / t_tg;
+
+        pp_avg += speed_pp;
+        tg_avg += speed_tg;
+
+        pp_std += speed_pp * speed_pp;
+        tg_std += speed_tg * speed_tg;
+
+        LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
+    }
+
+    pp_avg /= double(nr);
+    tg_avg /= double(nr);
+
+    if (nr > 1) {
+        pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
+        tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
+    } else {
+        pp_std = 0;
+        tg_std = 0;
+    }
+
+    char model_desc[128];
+    llama_model_desc(model, model_desc, sizeof(model_desc));
+
+    const auto model_size     = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
+    const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
+
+    const auto backend    = "(Android)"; // TODO: What should this be?
+
+    std::stringstream result;
+    result << std::setprecision(2);
+    result << "| model | size | params | backend | test | t/s |\n";
+    result << "| --- | --- | --- | --- | --- | --- |\n";
+    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
+    result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
+
+    return env->NewStringUTF(result.str().c_str());
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
+    llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
+}
+
+extern "C"
+JNIEXPORT jlong JNICALL
+Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
+
+    // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
+
+    llama_batch *batch = new llama_batch {
+        0,
+        nullptr,
+        nullptr,
+        nullptr,
+        nullptr,
+        nullptr,
+        nullptr,
+        0,
+        0,
+        0,
+    };
+
+    if (embd) {
+        batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
+    } else {
+        batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
+    }
+
+    batch->pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens);
+    batch->n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens);
+    batch->seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
+    for (int i = 0; i < n_tokens; ++i) {
+        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
+    }
+    batch->logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens);
+
+    return reinterpret_cast<jlong>(batch);
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
+    llama_backend_init();
+}
+
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) {
+    return env->NewStringUTF(llama_print_system_info());
+}
+
+extern "C"
+JNIEXPORT jint JNICALL
+Java_android_llama_cpp_LLamaAndroid_completion_1init(
+        JNIEnv *env,
+        jobject,
+        jlong context_pointer,
+        jlong batch_pointer,
+        jstring jtext,
+        jint n_len
+    ) {
+
+    cached_token_chars.clear();
+
+    const auto text = env->GetStringUTFChars(jtext, 0);
+    const auto context = reinterpret_cast<llama_context *>(context_pointer);
+    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+
+    const auto tokens_list = llama_tokenize(context, text, 1);
+
+    auto n_ctx = llama_n_ctx(context);
+    auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
+
+    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
+
+    if (n_kv_req > n_ctx) {
+        LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
+    }
+
+    for (auto id : tokens_list) {
+        LOGi("%s", llama_token_to_piece(context, id).c_str());
+    }
+
+    llama_batch_clear(*batch);
+
+    // evaluate the initial prompt
+    for (auto i = 0; i < tokens_list.size(); i++) {
+        llama_batch_add(*batch, tokens_list[i], i, { 0 }, false);
+    }
+
+    // llama_decode will output logits only for the last token of the prompt
+    batch->logits[batch->n_tokens - 1] = true;
+
+    if (llama_decode(context, *batch) != 0) {
+        LOGe("llama_decode() failed");
+    }
+
+    env->ReleaseStringUTFChars(jtext, text);
+
+    return batch->n_tokens;
+}
+
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_android_llama_cpp_LLamaAndroid_completion_1loop(
+        JNIEnv * env,
+        jobject,
+        jlong context_pointer,
+        jlong batch_pointer,
+        jint n_len,
+        jobject intvar_ncur
+) {
+    const auto context = reinterpret_cast<llama_context *>(context_pointer);
+    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+    const auto model = llama_get_model(context);
+
+    if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
+    if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
+    if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
+
+    auto n_vocab = llama_n_vocab(model);
+    auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
+
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+    }
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+    // sample the most likely token
+    const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
+
+    const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
+    if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
+        return env->NewStringUTF("");
+    }
+
+    auto new_token_chars = llama_token_to_piece(context, new_token_id);
+    cached_token_chars += new_token_chars;
+
+    jstring new_token = nullptr;
+    if (is_valid_utf8(cached_token_chars.c_str())) {
+        new_token = env->NewStringUTF(cached_token_chars.c_str());
+        LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
+        cached_token_chars.clear();
+    } else {
+        new_token = env->NewStringUTF("");
+    }
+
+    llama_batch_clear(*batch);
+    llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
+
+    env->CallVoidMethod(intvar_ncur, la_int_var_inc);
+
+    if (llama_decode(context, *batch) != 0) {
+        LOGe("llama_decode() returned null");
+    }
+
+    return new_token;
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
+    llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
+}
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
new file mode 100644 (file)
index 0000000..6c63e54
--- /dev/null
@@ -0,0 +1,172 @@
+package android.llama.cpp
+
+import android.util.Log
+import kotlinx.coroutines.CoroutineDispatcher
+import kotlinx.coroutines.asCoroutineDispatcher
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.withContext
+import java.util.concurrent.Executors
+import kotlin.concurrent.thread
+
+class LLamaAndroid {
+    private val tag: String? = this::class.simpleName
+
+    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
+
+    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
+        thread(start = false, name = "Llm-RunLoop") {
+            Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
+
+            // No-op if called more than once.
+            System.loadLibrary("llama-android")
+
+            // Set llama log handler to Android
+            log_to_android()
+            backend_init(false)
+
+            Log.d(tag, system_info())
+
+            it.run()
+        }.apply {
+            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
+                Log.e(tag, "Unhandled exception", exception)
+            }
+        }
+    }.asCoroutineDispatcher()
+
+    private val nlen: Int = 64
+
+    private external fun log_to_android()
+    private external fun load_model(filename: String): Long
+    private external fun free_model(model: Long)
+    private external fun new_context(model: Long): Long
+    private external fun free_context(context: Long)
+    private external fun backend_init(numa: Boolean)
+    private external fun backend_free()
+    private external fun free_batch(batch: Long)
+    private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
+    private external fun bench_model(
+        context: Long,
+        model: Long,
+        batch: Long,
+        pp: Int,
+        tg: Int,
+        pl: Int,
+        nr: Int
+    ): String
+
+    private external fun system_info(): String
+
+    private external fun completion_init(
+        context: Long,
+        batch: Long,
+        text: String,
+        nLen: Int
+    ): Int
+
+    private external fun completion_loop(
+        context: Long,
+        batch: Long,
+        nLen: Int,
+        ncur: IntVar
+    ): String?
+
+    private external fun kv_cache_clear(context: Long)
+
+    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
+        return withContext(runLoop) {
+            when (val state = threadLocalState.get()) {
+                is State.Loaded -> {
+                    Log.d(tag, "bench(): $state")
+                    bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
+                }
+
+                else -> throw IllegalStateException("No model loaded")
+            }
+        }
+    }
+
+    suspend fun load(pathToModel: String) {
+        withContext(runLoop) {
+            when (threadLocalState.get()) {
+                is State.Idle -> {
+                    val model = load_model(pathToModel)
+                    if (model == 0L)  throw IllegalStateException("load_model() failed")
+
+                    val context = new_context(model)
+                    if (context == 0L) throw IllegalStateException("new_context() failed")
+
+                    val batch = new_batch(512, 0, 1)
+                    if (batch == 0L) throw IllegalStateException("new_batch() failed")
+
+                    Log.i(tag, "Loaded model $pathToModel")
+                    threadLocalState.set(State.Loaded(model, context, batch))
+                }
+                else -> throw IllegalStateException("Model already loaded")
+            }
+        }
+    }
+
+    fun send(message: String): Flow<String> = flow {
+        when (val state = threadLocalState.get()) {
+            is State.Loaded -> {
+                val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
+                while (ncur.value <= nlen) {
+                    val str = completion_loop(state.context, state.batch, nlen, ncur)
+                    if (str == null) {
+                        break
+                    }
+                    emit(str)
+                }
+                kv_cache_clear(state.context)
+            }
+            else -> {}
+        }
+    }.flowOn(runLoop)
+
+    /**
+     * Unloads the model and frees resources.
+     *
+     * This is a no-op if there's no model loaded.
+     */
+    suspend fun unload() {
+        withContext(runLoop) {
+            when (val state = threadLocalState.get()) {
+                is State.Loaded -> {
+                    free_context(state.context)
+                    free_model(state.model)
+                    free_batch(state.batch)
+
+                    threadLocalState.set(State.Idle)
+                }
+                else -> {}
+            }
+        }
+    }
+
+    companion object {
+        private class IntVar(value: Int) {
+            @Volatile
+            var value: Int = value
+                private set
+
+            fun inc() {
+                synchronized(this) {
+                    value += 1
+                }
+            }
+        }
+
+        private sealed interface State {
+            data object Idle: State
+            data class Loaded(val model: Long, val context: Long, val batch: Long): State
+        }
+
+        // Enforce only one instance of Llm.
+        private val _instance: LLamaAndroid = LLamaAndroid()
+
+        fun instance(): LLamaAndroid = _instance
+    }
+}
diff --git a/examples/llama.android/llama/src/test/java/android/llama/cpp/ExampleUnitTest.kt b/examples/llama.android/llama/src/test/java/android/llama/cpp/ExampleUnitTest.kt
new file mode 100644 (file)
index 0000000..cbbb974
--- /dev/null
@@ -0,0 +1,17 @@
+package android.llama.cpp
+
+import org.junit.Test
+
+import org.junit.Assert.*
+
+/**
+ * Example local unit test, which will execute on the development machine (host).
+ *
+ * See [testing documentation](http://d.android.com/tools/testing).
+ */
+class ExampleUnitTest {
+    @Test
+    fun addition_isCorrect() {
+        assertEquals(4, 2 + 2)
+    }
+}
index 2ba32c4fafc5c968f09cd01288fed48117720eef..c7c1a034a45b8e14698bc415b5a3ee902e6684f1 100644 (file)
@@ -15,3 +15,4 @@ dependencyResolutionManagement {
 
 rootProject.name = "LlamaAndroid"
 include(":app")
+include(":llama")