git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper.android : support benchmark for Android example. (#542)
author Takeshi Inoue <redacted>
Tue, 7 Mar 2023 19:36:30 +0000 (04:36 +0900)
committer GitHub <redacted>
Tue, 7 Mar 2023 19:36:30 +0000 (21:36 +0200)
* whisper.android: Support benchmark for Android example.

* whisper.android: update screenshot in README.

* update: Make text selectable for copy & paste.

* Update whisper.h to restore API name

Co-authored-by: Georgi Gerganov <redacted>
* whisper.android: Restore original API names.

---------

Co-authored-by: tinoue <redacted>
Co-authored-by: Georgi Gerganov <redacted>
examples/whisper.android/README.md
examples/whisper.android/app/src/main/java/com/whispercppdemo/ui/main/MainScreen.kt
examples/whisper.android/app/src/main/java/com/whispercppdemo/ui/main/MainScreenViewModel.kt
examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/LibWhisper.kt
examples/whisper.android/app/src/main/jni/whisper/jni.c
whisper.cpp
whisper.h

index ae47d401a06547b7d98ec76fc9b4fdde007e28e2..57ee78daa7a03096308848769ff0f0160a2fcf4f 100644 (file)
@@ -9,4 +9,4 @@ To use:
 5. Select the "release" active build variant, and use Android Studio to run and deploy to your device.
 [^1]: I recommend the tiny or base models for running on an Android device.
 
-<img width="300" alt="image" src="https://user-images.githubusercontent.com/1991296/208154256-82d972dc-221b-48c4-bfcb-36ce68602f93.png">
+<img width="300" alt="image" src="https://user-images.githubusercontent.com/1670775/221613663-a17bf770-27ef-45ab-9a46-a5f99ba65d2a.jpg">
index f05f56cb3bb56dda5384d6c3236146155cf5c998..30128f3acfedc048360886de620ead1c5e2a3cb0 100644 (file)
@@ -2,6 +2,7 @@ package com.whispercppdemo.ui.main
 
 import androidx.compose.foundation.layout.*
 import androidx.compose.foundation.rememberScrollState
+import androidx.compose.foundation.text.selection.SelectionContainer
 import androidx.compose.foundation.verticalScroll
 import androidx.compose.material3.*
 import androidx.compose.runtime.Composable
@@ -19,6 +20,7 @@ fun MainScreen(viewModel: MainScreenViewModel) {
         canTranscribe = viewModel.canTranscribe,
         isRecording = viewModel.isRecording,
         messageLog = viewModel.dataLog,
+        onBenchmarkTapped = viewModel::benchmark,
         onTranscribeSampleTapped = viewModel::transcribeSample,
         onRecordTapped = viewModel::toggleRecord
     )
@@ -30,6 +32,7 @@ private fun MainScreen(
     canTranscribe: Boolean,
     isRecording: Boolean,
     messageLog: String,
+    onBenchmarkTapped: () -> Unit,
     onTranscribeSampleTapped: () -> Unit,
     onRecordTapped: () -> Unit
 ) {
@@ -45,8 +48,11 @@ private fun MainScreen(
                 .padding(innerPadding)
                 .padding(16.dp)
         ) {
-            Row(horizontalArrangement = Arrangement.SpaceBetween) {
-                TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped)
+            Column(verticalArrangement = Arrangement.SpaceBetween) {
+                Row(horizontalArrangement = Arrangement.SpaceBetween, modifier = Modifier.fillMaxWidth()) {
+                    BenchmarkButton(enabled = canTranscribe, onClick = onBenchmarkTapped)
+                    TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped)
+                }
                 RecordButton(
                     enabled = canTranscribe,
                     isRecording = isRecording,
@@ -60,7 +66,16 @@ private fun MainScreen(
 
 @Composable
 private fun MessageLog(log: String) {
-    Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log)
+    SelectionContainer() {
+        Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log)
+    }
+}
+
+/** Button labelled "Benchmark" that forwards taps to [onClick]; it is disabled when [enabled] is false. */
+@Composable
+private fun BenchmarkButton(enabled: Boolean, onClick: () -> Unit) {
+    Button(enabled = enabled, onClick = onClick) {
+        Text(text = "Benchmark")
+    }
 }
 
 @Composable
index 29ac2cd3bd426586ec538857356b25a701c92d5f..269f0c2a15c83ae836756e428b5580e809bf52c1 100644 (file)
@@ -41,10 +41,15 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
 
     init {
         viewModelScope.launch {
+            printSystemInfo()
             loadData()
         }
     }
 
+    /** Appends the string returned by [WhisperContext.getSystemInfo] to the message log. */
+    private suspend fun printSystemInfo() {
+        printMessage("System Info: ${WhisperContext.getSystemInfo()}\n")
+    }
+
     private suspend fun loadData() {
         printMessage("Loading data...\n")
         try {
@@ -81,10 +86,29 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
         //whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath)
     }
 
+    /** Launches the benchmark run in [viewModelScope] with a fixed thread count of 6. */
+    fun benchmark() = viewModelScope.launch {
+        runBenchmark(nthreads = 6)
+    }
+
     fun transcribeSample() = viewModelScope.launch {
         transcribeAudio(getFirstSample())
     }
 
+    /**
+     * Runs the memory benchmark followed by the ggml mul_mat benchmark and appends
+     * each report to the message log.
+     *
+     * Returns immediately when [canTranscribe] is false. Otherwise [canTranscribe] is
+     * held false for the duration of the run and restored in a `finally` block, so a
+     * failing or cancelled benchmark cannot leave it stuck at false.
+     *
+     * @param nthreads thread count forwarded to both benchmark calls.
+     */
+    private suspend fun runBenchmark(nthreads: Int) {
+        if (!canTranscribe) {
+            return
+        }
+
+        canTranscribe = false
+        try {
+            printMessage("Running benchmark. This will take minutes...\n")
+            whisperContext?.benchMemory(nthreads)?.let { printMessage(it) }
+            printMessage("\n")
+            whisperContext?.benchGgmlMulMat(nthreads)?.let { printMessage(it) }
+        } finally {
+            // Always restore the flag, even if a benchmark call throws or the coroutine is cancelled.
+            canTranscribe = true
+        }
+    }
+
     private suspend fun getFirstSample(): File = withContext(Dispatchers.IO) {
         samplesPath.listFiles()!!.first()
     }
@@ -114,11 +138,14 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
         canTranscribe = false
 
         try {
-            printMessage("Reading wave samples...\n")
+            printMessage("Reading wave samples... ")
             val data = readAudioSamples(file)
+            printMessage("${data.size / (16000 / 1000)} ms\n")
             printMessage("Transcribing data...\n")
+            val start = System.currentTimeMillis()
             val text = whisperContext?.transcribeData(data)
-            printMessage("Done: $text\n")
+            val elapsed = System.currentTimeMillis() - start
+            printMessage("Done ($elapsed ms): $text\n")
         } catch (e: Exception) {
             Log.w(LOG_TAG, e)
             printMessage("${e.localizedMessage}\n")
index b0b42003d422a027be14f754518a3fd02a7191c8..a2b651c72ab5259684509c0d3b066cbe8547833d 100644 (file)
@@ -27,6 +27,14 @@ class WhisperContext private constructor(private var ptr: Long) {
         }
     }
 
+    /**
+     * Runs [WhisperLib.benchMemcpy] inside this context's [scope] and returns its report.
+     *
+     * @param nthreads thread count passed to the native benchmark.
+     */
+    suspend fun benchMemory(nthreads: Int): String = withContext(scope.coroutineContext) {
+        WhisperLib.benchMemcpy(nthreads)
+    }
+
+    /**
+     * Runs [WhisperLib.benchGgmlMulMat] inside this context's [scope] and returns its report.
+     *
+     * @param nthreads thread count passed to the native benchmark.
+     */
+    suspend fun benchGgmlMulMat(nthreads: Int): String = withContext(scope.coroutineContext) {
+        WhisperLib.benchGgmlMulMat(nthreads)
+    }
+
     suspend fun release() = withContext(scope.coroutineContext) {
         if (ptr != 0L) {
             WhisperLib.freeContext(ptr)
@@ -66,6 +74,10 @@ class WhisperContext private constructor(private var ptr: Long) {
             }
             return WhisperContext(ptr)
         }
+
+        /** Returns the system-info string reported by the native [WhisperLib.getSystemInfo] call. */
+        fun getSystemInfo(): String = WhisperLib.getSystemInfo()
     }
 }
 
@@ -117,6 +129,9 @@ private class WhisperLib {
         external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
         external fun getTextSegmentCount(contextPtr: Long): Int
         external fun getTextSegment(contextPtr: Long, index: Int): String
+        external fun getSystemInfo(): String
+        external fun benchMemcpy(nthread: Int): String
+        external fun benchGgmlMulMat(nthread: Int): String
     }
 }
 
index 160fe78f29aceb4a6a3943c7b0fee0f50d3ffa94..82dfd77808e94678d575a7489e1b10fa5e4bdf87 100644 (file)
@@ -6,6 +6,7 @@
 #include <sys/sysinfo.h>
 #include <string.h>
 #include "whisper.h"
+#include "ggml.h"
 
 #define UNUSED(x) (void)(x)
 #define TAG "JNI"
@@ -213,4 +214,30 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getTextSegment(
     const char *text = whisper_full_get_segment_text(context, index);
     jstring string = (*env)->NewStringUTF(env, text);
     return string;
-}
\ No newline at end of file
+}
+
+/*
+ * JNI bridge for WhisperLib.getSystemInfo(): converts the C string returned by
+ * whisper_print_system_info() into a Java string.
+ */
+JNIEXPORT jstring JNICALL
+Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getSystemInfo(
+        JNIEnv *env, jobject thiz
+) {
+    UNUSED(thiz);
+    return (*env)->NewStringUTF(env, whisper_print_system_info());
+}
+
+/*
+ * JNI bridge for WhisperLib.benchMemcpy(): runs whisper_bench_memcpy_str() with the
+ * requested thread count and returns its report as a Java string.
+ */
+JNIEXPORT jstring JNICALL
+Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_benchMemcpy(JNIEnv *env, jobject thiz,
+                                                                      jint n_threads) {
+    UNUSED(thiz);
+    const char *bench_ggml_memcpy = whisper_bench_memcpy_str(n_threads);
+    jstring string = (*env)->NewStringUTF(env, bench_ggml_memcpy);
+    // The function is declared to return jstring; falling off the end without a
+    // return is undefined behaviour once the caller uses the result.
+    return string;
+}
+
+/*
+ * JNI bridge for WhisperLib.benchGgmlMulMat(): runs whisper_bench_ggml_mul_mat_str()
+ * with the requested thread count and returns its report as a Java string.
+ */
+JNIEXPORT jstring JNICALL
+Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_benchGgmlMulMat(JNIEnv *env, jobject thiz,
+                                                                          jint n_threads) {
+    UNUSED(thiz);
+    const char *bench_ggml_mul_mat = whisper_bench_ggml_mul_mat_str(n_threads);
+    jstring string = (*env)->NewStringUTF(env, bench_ggml_mul_mat);
+    // The function is declared to return jstring; falling off the end without a
+    // return is undefined behaviour once the caller uses the result.
+    return string;
+}
index c8a904bb040e9c8c32018a3cd44aedb86711a1f6..14b04d7a1a2f28a9bc6332de972519673ba90a2e 100644 (file)
@@ -4551,6 +4551,15 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
 //
 
 WHISPER_API int whisper_bench_memcpy(int n_threads) {
+    fputs(whisper_bench_memcpy_str(n_threads), stderr);
+    return 0;
+}
+
+WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
+    static std::string s;
+    s = "";
+    char strbuf[256];
+
     ggml_time_init();
 
     size_t n    = 50;
@@ -4580,7 +4589,8 @@ WHISPER_API int whisper_bench_memcpy(int n_threads) {
         src[0] = rand();
     }
 
-    fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
+    snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
+    s += strbuf;
 
     // needed to prevent the compile from optimizing the memcpy away
     {
@@ -4588,16 +4598,26 @@ WHISPER_API int whisper_bench_memcpy(int n_threads) {
 
         for (size_t i = 0; i < size; i++) sum += dst[i];
 
-        fprintf(stderr, "sum:    %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
+        snprintf(strbuf, sizeof(strbuf), "sum:    %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
+        s += strbuf;
     }
 
     free(src);
     free(dst);
 
-    return 0;
+    return s.c_str();
 }
 
 WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
+    fputs(whisper_bench_ggml_mul_mat_str(n_threads), stderr);
+    return 0;
+}
+
+WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
+    static std::string s;
+    s = "";
+    char strbuf[256];
+
     ggml_time_init();
 
     const int n_max = 128;
@@ -4673,11 +4693,12 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
             s = ((2.0*N*N*N*n)/tsum)*1e-9;
         }
 
-        fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
+        snprintf(strbuf, sizeof(strbuf), "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
             N, N, s_fp16, n_fp16, s_fp32, n_fp32);
+        s += strbuf;
     }
 
-    return 0;
+    return s.c_str();
 }
 
 // =================================================================================================
index 3984195dc8c1f93973cc8dd230a112d4ee7988a6..0a8270db941d0871023050aaeab8ee2922e90f52 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -462,7 +462,9 @@ extern "C" {
     // Temporary helpers needed for exposing ggml interface
 
     WHISPER_API int whisper_bench_memcpy(int n_threads);
+    WHISPER_API const char * whisper_bench_memcpy_str(int n_threads);
     WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
+    WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads);
 
 #ifdef __cplusplus
 }