suspend fun transcribeData(data: FloatArray): String = withContext(scope.coroutineContext) {
require(ptr != 0L)
- WhisperLib.fullTranscribe(ptr, data)
+ val numThreads = WhisperCpuConfig.preferredThreadCount
+ Log.d(LOG_TAG, "Selecting $numThreads threads")
+ WhisperLib.fullTranscribe(ptr, numThreads, data)
val textCount = WhisperLib.getTextSegmentCount(ptr)
return@withContext buildString {
for (i in 0 until textCount) {
external fun initContextFromAsset(assetManager: AssetManager, assetPath: String): Long
external fun initContext(modelPath: String): Long
external fun freeContext(contextPtr: Long)
- external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
+ external fun fullTranscribe(contextPtr: Long, numThreads: Int, audioData: FloatArray)
external fun getTextSegmentCount(contextPtr: Long): Int
external fun getTextSegment(contextPtr: Long, index: Int): String
external fun getSystemInfo(): String
--- /dev/null
+package com.whispercppdemo.whisper
+
+import android.util.Log
+import java.io.BufferedReader
+import java.io.FileReader
+
+/**
+ * Thread-count policy for whisper transcription.
+ */
+object WhisperCpuConfig {
+    /**
+     * Number of threads to request for transcription: the detected high-performance
+     * core count, but never fewer than 2. Evaluated afresh on every access.
+     */
+    val preferredThreadCount: Int
+        get() = maxOf(2, CpuInfo.getHighPerfCpuCount())
+}
+
+/**
+ * View over the lines of `/proc/cpuinfo`, used to estimate how many
+ * high-performance cores the device has.
+ *
+ * @property lines the raw lines of `/proc/cpuinfo`, one entry per line.
+ */
+private class CpuInfo(private val lines: List<String>) {
+    /**
+     * Estimates the high-performance core count. Per-core maximum frequencies are
+     * tried first; if reading them fails for any reason, the "CPU variant" field of
+     * the cpuinfo lines is used instead.
+     */
+    private fun getHighPerfCpuCount(): Int = try {
+        getHighPerfCpuCountByFrequencies()
+    } catch (e: Exception) {
+        Log.d(LOG_TAG, "Couldn't read CPU frequencies", e)
+        getHighPerfCpuCountByVariant()
+    }
+
+    /**
+     * Looks up the maximum frequency of every "processor" entry and counts the cores
+     * that are faster than the slowest one (or all cores if they are all equal).
+     */
+    private fun getHighPerfCpuCountByFrequencies(): Int =
+        getCpuValues(property = "processor") { getMaxCpuFrequency(it.toInt()) }
+            .also { Log.d(LOG_TAG, "Binned cpu frequencies (frequency, count): ${it.binnedValues()}") }
+            .countDroppingMin()
+
+    /**
+     * Parses the hexadecimal "CPU variant" of every core and counts the cores that
+     * share the lowest variant value.
+     */
+    private fun getHighPerfCpuCountByVariant(): Int =
+        getCpuValues(property = "CPU variant") { it.substringAfter("0x").toInt(radix = 16) }
+            .also { Log.d(LOG_TAG, "Binned cpu variants (variant, count): ${it.binnedValues()}") }
+            .countKeepingMin()
+
+    /** Histogram of the values (value -> number of occurrences), used for logging only. */
+    private fun List<Int>.binnedValues() = groupingBy { it }.eachCount()
+
+    /**
+     * Collects, in ascending order, the mapped value of every cpuinfo line that starts
+     * with [property]. [mapper] receives the trimmed text after the first ':'.
+     */
+    private fun getCpuValues(property: String, mapper: (String) -> Int) = lines
+        .asSequence()
+        .filter { it.startsWith(property) }
+        .map { mapper(it.substringAfter(':').trim()) }
+        .sorted()
+        .toList()
+
+    /**
+     * Counts the entries strictly greater than the minimum, i.e. drops the slowest group.
+     *
+     * When every entry is equal there is no slower group to drop, so all entries are
+     * counted instead of reporting zero. An empty list still throws
+     * [NoSuchElementException] (from `min()`), which lets callers fall back to
+     * another detection strategy.
+     */
+    private fun List<Int>.countDroppingMin(): Int {
+        val min = min()
+        val aboveMin = count { it > min }
+        return if (aboveMin == 0) size else aboveMin
+    }
+
+    /**
+     * Counts the entries equal to the minimum.
+     * Throws [NoSuchElementException] on an empty list.
+     */
+    private fun List<Int>.countKeepingMin(): Int {
+        val min = min()
+        return count { it == min }
+    }
+
+    companion object {
+        private const val LOG_TAG = "WhisperCpuConfig"
+
+        /**
+         * Reads `/proc/cpuinfo` and estimates the number of high-performance cores.
+         * If that fails, falls back to the available processor count minus 4,
+         * floored at 0.
+         */
+        fun getHighPerfCpuCount(): Int = try {
+            readCpuInfo().getHighPerfCpuCount()
+        } catch (e: Exception) {
+            Log.d(LOG_TAG, "Couldn't read CPU info", e)
+            // Our best guess -- just return the # of CPUs minus 4.
+            (Runtime.getRuntime().availableProcessors() - 4).coerceAtLeast(0)
+        }
+
+        /** Loads every line of `/proc/cpuinfo` into a [CpuInfo]. */
+        private fun readCpuInfo() = CpuInfo(
+            BufferedReader(FileReader("/proc/cpuinfo"))
+                .useLines { it.toList() }
+        )
+
+        /**
+         * Reads the first line of the `cpuinfo_max_freq` sysfs file of core [cpuIndex]
+         * and parses it as an integer.
+         *
+         * @throws IllegalStateException if the file is empty.
+         * @throws NumberFormatException if the line is not an integer.
+         */
+        private fun getMaxCpuFrequency(cpuIndex: Int): Int {
+            val path = "/sys/devices/system/cpu/cpu${cpuIndex}/cpufreq/cpuinfo_max_freq"
+            // readLine() returns null for an empty file; fail with a descriptive
+            // message instead of an NPE on the platform-typed result.
+            val maxFreq: String? = BufferedReader(FileReader(path)).use { it.readLine() }
+            return checkNotNull(maxFreq) { "Empty CPU frequency file: $path" }.trim().toInt()
+        }
+    }
+}
\ No newline at end of file
JNIEXPORT void JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
- JNIEnv *env, jobject thiz, jlong context_ptr, jfloatArray audio_data) {
+ JNIEnv *env, jobject thiz, jlong context_ptr, jint num_threads, jfloatArray audio_data) {
UNUSED(thiz);
struct whisper_context *context = (struct whisper_context *) context_ptr;
jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);
- // Leave 2 processors free (i.e. the high-efficiency cores).
- int max_threads = max(1, min(8, get_nprocs() - 2));
- LOGI("Selecting %d threads", max_threads);
-
// The below adapted from the Objective-C iOS sample
struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
params.print_realtime = true;
params.print_special = false;
params.translate = false;
params.language = "en";
- params.n_threads = max_threads;
+ params.n_threads = num_threads;
params.offset_ms = 0;
params.no_context = true;
params.single_segment = false;