]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper.android : address ARM's big.LITTLE arch by checking cpu info (#1254)
authorDigipom <redacted>
Wed, 6 Sep 2023 15:32:30 +0000 (11:32 -0400)
committerGitHub <redacted>
Wed, 6 Sep 2023 15:32:30 +0000 (18:32 +0300)
Addresses https://github.com/ggerganov/whisper.cpp/issues/1248

examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/LibWhisper.kt
examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/WhisperCpuConfig.kt [new file with mode: 0644]
examples/whisper.android/app/src/main/jni/whisper/jni.c

index a2b651c72ab5259684509c0d3b066cbe8547833d..b0d670372ea88a9fec642e9d32f82819735e9822 100644 (file)
@@ -18,7 +18,9 @@ class WhisperContext private constructor(private var ptr: Long) {
 
     suspend fun transcribeData(data: FloatArray): String = withContext(scope.coroutineContext) {
         require(ptr != 0L)
-        WhisperLib.fullTranscribe(ptr, data)
+        val numThreads = WhisperCpuConfig.preferredThreadCount
+        Log.d(LOG_TAG, "Selecting $numThreads threads")
+        WhisperLib.fullTranscribe(ptr, numThreads, data)
         val textCount = WhisperLib.getTextSegmentCount(ptr)
         return@withContext buildString {
             for (i in 0 until textCount) {
@@ -126,7 +128,7 @@ private class WhisperLib {
         external fun initContextFromAsset(assetManager: AssetManager, assetPath: String): Long
         external fun initContext(modelPath: String): Long
         external fun freeContext(contextPtr: Long)
-        external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
+        external fun fullTranscribe(contextPtr: Long, numThreads: Int, audioData: FloatArray)
         external fun getTextSegmentCount(contextPtr: Long): Int
         external fun getTextSegment(contextPtr: Long, index: Int): String
         external fun getSystemInfo(): String
diff --git a/examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/WhisperCpuConfig.kt b/examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/WhisperCpuConfig.kt
new file mode 100644 (file)
index 0000000..5fa9a4e
--- /dev/null
@@ -0,0 +1,73 @@
+package com.whispercppdemo.whisper
+
+import android.util.Log
+import java.io.BufferedReader
+import java.io.FileReader
+
+object WhisperCpuConfig {
+    val preferredThreadCount: Int
+        // Always use at least 2 threads:
+        get() = CpuInfo.getHighPerfCpuCount().coerceAtLeast(2)
+}
+
+private class CpuInfo(private val lines: List<String>) {
+    private fun getHighPerfCpuCount(): Int = try {
+        getHighPerfCpuCountByFrequencies()
+    } catch (e: Exception) {
+        Log.d(LOG_TAG, "Couldn't read CPU frequencies", e)
+        getHighPerfCpuCountByVariant()
+    }
+
+    private fun getHighPerfCpuCountByFrequencies(): Int =
+        getCpuValues(property = "processor") { getMaxCpuFrequency(it.toInt()) }
+            .also { Log.d(LOG_TAG, "Binned cpu frequencies (frequency, count): ${it.binnedValues()}") }
+            .countDroppingMin()
+
+    private fun getHighPerfCpuCountByVariant(): Int =
+        getCpuValues(property = "CPU variant") { it.substringAfter("0x").toInt(radix = 16) }
+            .also { Log.d(LOG_TAG, "Binned cpu variants (variant, count): ${it.binnedValues()}") }
+            .countKeepingMin()
+
+    private fun List<Int>.binnedValues() = groupingBy { it }.eachCount()
+
+    private fun getCpuValues(property: String, mapper: (String) -> Int) = lines
+        .asSequence()
+        .filter { it.startsWith(property) }
+        .map { mapper(it.substringAfter(':').trim()) }
+        .sorted()
+        .toList()
+
+
+    private fun List<Int>.countDroppingMin(): Int {
+        val min = min()
+        return count { it > min }
+    }
+
+    private fun List<Int>.countKeepingMin(): Int {
+        val min = min()
+        return count { it == min }
+    }
+
+    companion object {
+        private const val LOG_TAG = "WhisperCpuConfig"
+
+        fun getHighPerfCpuCount(): Int = try {
+            readCpuInfo().getHighPerfCpuCount()
+        } catch (e: Exception) {
+            Log.d(LOG_TAG, "Couldn't read CPU info", e)
+            // Our best guess -- just return the # of CPUs minus 4.
+            (Runtime.getRuntime().availableProcessors() - 4).coerceAtLeast(0)
+        }
+
+        private fun readCpuInfo() = CpuInfo(
+            BufferedReader(FileReader("/proc/cpuinfo"))
+                .useLines { it.toList() }
+        )
+
+        private fun getMaxCpuFrequency(cpuIndex: Int): Int {
+            val path = "/sys/devices/system/cpu/cpu${cpuIndex}/cpufreq/cpuinfo_max_freq"
+            val maxFreq = BufferedReader(FileReader(path)).use { it.readLine() }
+            return maxFreq.toInt()
+        }
+    }
+}
\ No newline at end of file
index 82dfd77808e94678d575a7489e1b10fa5e4bdf87..c437d0990f1a6fca75982e41788288a71e2d0619 100644 (file)
@@ -163,16 +163,12 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_freeContext(
 
 JNIEXPORT void JNICALL
 Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
-        JNIEnv *env, jobject thiz, jlong context_ptr, jfloatArray audio_data) {
+        JNIEnv *env, jobject thiz, jlong context_ptr, jint num_threads, jfloatArray audio_data) {
     UNUSED(thiz);
     struct whisper_context *context = (struct whisper_context *) context_ptr;
     jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
     const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);
 
-    // Leave 2 processors free (i.e. the high-efficiency cores).
-    int max_threads = max(1, min(8, get_nprocs() - 2));
-    LOGI("Selecting %d threads", max_threads);
-
     // The below adapted from the Objective-C iOS sample
     struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
     params.print_realtime = true;
@@ -181,7 +177,7 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
     params.print_special = false;
     params.translate = false;
     params.language = "en";
-    params.n_threads = max_threads;
+    params.n_threads = num_threads;
     params.offset_ms = 0;
     params.no_context = true;
     params.single_segment = false;