]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : add loader class to allow loading from buffer and others (#353)
authorSyahmi Azhar <redacted>
Sun, 8 Jan 2023 11:03:33 +0000 (19:03 +0800)
committerGitHub <redacted>
Sun, 8 Jan 2023 11:03:33 +0000 (13:03 +0200)
* whisper : add loader to allow loading from other than file

* whisper : rename whisper_init to whisper_init_from_file

* whisper : add whisper_init_from_buffer

* android : Delete local.properties

* android : load models directly from assets

* whisper : adding <stddef.h> needed for size_t + code style

Co-authored-by: Georgi Gerganov <redacted>
20 files changed:
bindings/go/whisper.go
bindings/javascript/emscripten.cpp
examples/bench.wasm/emscripten.cpp
examples/bench/bench.cpp
examples/command.wasm/emscripten.cpp
examples/command/command.cpp
examples/main/main.cpp
examples/stream.wasm/emscripten.cpp
examples/stream/stream.cpp
examples/talk.wasm/emscripten.cpp
examples/talk/talk.cpp
examples/whisper.android/app/src/main/java/com/whispercppdemo/ui/main/MainScreenViewModel.kt
examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/LibWhisper.kt
examples/whisper.android/app/src/main/jni/whisper/jni.c
examples/whisper.android/local.properties [deleted file]
examples/whisper.objc/whisper.objc/ViewController.m
examples/whisper.swiftui/whisper.cpp.swift/LibWhisper.swift
examples/whisper.wasm/emscripten.cpp
whisper.cpp
whisper.h

index 9381879c25dfb46225d6249779ad011197386fe2..8d12fed10be79d586752d58a30492a0d6dffbbb3 100644 (file)
@@ -91,7 +91,7 @@ var (
 func Whisper_init(path string) *Context {
        cPath := C.CString(path)
        defer C.free(unsafe.Pointer(cPath))
-       if ctx := C.whisper_init(cPath); ctx != nil {
+       if ctx := C.whisper_init_from_file(cPath); ctx != nil {
                return (*Context)(ctx)
        } else {
                return nil
index bda963017b3c443561f4ed8df719b31646d93f69..789ad8b51f8cdd58ec4237ee24442d1c33b69e1a 100644 (file)
@@ -20,7 +20,7 @@ struct whisper_context * g_context;
 EMSCRIPTEN_BINDINGS(whisper) {
     emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
         if (g_context == nullptr) {
-            g_context = whisper_init(path_model.c_str());
+            g_context = whisper_init_from_file(path_model.c_str());
             if (g_context != nullptr) {
                 return true;
             } else {
index 2e63315d671685cb845fed0dae9119d3d5d55eb5..959414799cc543ee89f483ac99e8c3f5e8778a45 100644 (file)
@@ -52,7 +52,7 @@ EMSCRIPTEN_BINDINGS(bench) {
     emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
         for (size_t i = 0; i < g_contexts.size(); ++i) {
             if (g_contexts[i] == nullptr) {
-                g_contexts[i] = whisper_init(path_model.c_str());
+                g_contexts[i] = whisper_init_from_file(path_model.c_str());
                 if (g_contexts[i] != nullptr) {
                     if (g_worker.joinable()) {
                         g_worker.join();
index 3ab5077b87af3886e8f9d7400d3e2208b743cfa9..2fd2423f5fa1517916ef32866c9c2ea761ba871a 100644 (file)
@@ -53,7 +53,7 @@ int main(int argc, char ** argv) {
 
     // whisper init
 
-    struct whisper_context * ctx = whisper_init(params.model.c_str());
+    struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
 
     {
         fprintf(stderr, "\n");
index d4bbb212327f0dbe01e3e0ff609af7db8f3b750c..f2ba81e96001ffe1e0eae57f1556c0ed74d5faf4 100644 (file)
@@ -324,7 +324,7 @@ EMSCRIPTEN_BINDINGS(command) {
     emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
         for (size_t i = 0; i < g_contexts.size(); ++i) {
             if (g_contexts[i] == nullptr) {
-                g_contexts[i] = whisper_init(path_model.c_str());
+                g_contexts[i] = whisper_init_from_file(path_model.c_str());
                 if (g_contexts[i] != nullptr) {
                     g_running = true;
                     if (g_worker.joinable()) {
index 4558a67dae9b575e0ab9f1e05d124866d4c3305e..3dae3a5e31c7b2b6f358ebe9475abdafe2c1f1b3 100644 (file)
@@ -931,7 +931,7 @@ int main(int argc, char ** argv) {
 
     // whisper init
 
-    struct whisper_context * ctx = whisper_init(params.model.c_str());
+    struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
 
     // print some info about the processing
     {
index d387a77c8153fe9943b89b318166ec1b1e0e5b79..48e02923d017fd01c326046a2f83f3ac98ef7fd0 100644 (file)
@@ -478,7 +478,7 @@ int main(int argc, char ** argv) {
 
     // whisper init
 
-    struct whisper_context * ctx = whisper_init(params.model.c_str());
+    struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
 
     if (ctx == nullptr) {
         fprintf(stderr, "error: failed to initialize whisper context\n");
index b75eee365aaab66004be9f9476f957901474b767..e4cdf639a406f1488a27efbde47c3b3025cbef5d 100644 (file)
@@ -129,7 +129,7 @@ EMSCRIPTEN_BINDINGS(stream) {
     emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
         for (size_t i = 0; i < g_contexts.size(); ++i) {
             if (g_contexts[i] == nullptr) {
-                g_contexts[i] = whisper_init(path_model.c_str());
+                g_contexts[i] = whisper_init_from_file(path_model.c_str());
                 if (g_contexts[i] != nullptr) {
                     g_running = true;
                     if (g_worker.joinable()) {
index 9caa6148d890df7e22e1d7719f419d9a9b9454a2..c7aa87178f9ff48127b999609b581a8317916b80 100644 (file)
@@ -456,7 +456,7 @@ int main(int argc, char ** argv) {
         exit(0);
     }
 
-    struct whisper_context * ctx = whisper_init(params.model.c_str());
+    struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
 
     std::vector<float> pcmf32    (n_samples_30s, 0.0f);
     std::vector<float> pcmf32_old(n_samples_30s, 0.0f);
index c82f4696d8ebd156ca5f74f8ca8a797ea624d99b..1ea970295ac066f3aa06b874e4b2ad808b6861eb 100644 (file)
@@ -271,7 +271,7 @@ EMSCRIPTEN_BINDINGS(talk) {
     emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
         for (size_t i = 0; i < g_contexts.size(); ++i) {
             if (g_contexts[i] == nullptr) {
-                g_contexts[i] = whisper_init(path_model.c_str());
+                g_contexts[i] = whisper_init_from_file(path_model.c_str());
                 if (g_contexts[i] != nullptr) {
                     g_running = true;
                     if (g_worker.joinable()) {
index ec57a95cd70119e37d1ef86ab0f9633b5da89275..55cd46a9bbfcb36bb50b67bef31ca6308d4465c8 100644 (file)
@@ -498,7 +498,7 @@ int main(int argc, char ** argv) {
 
     // whisper init
 
-    struct whisper_context * ctx_wsp = whisper_init(params.model_wsp.c_str());
+    struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
 
     // gpt init
 
index 866444092e0b775a03d1caf12e2a36be18ed6186..d7417482a1894ed37a673c1c3e5722e58af66e8b 100644 (file)
@@ -64,16 +64,22 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
     private suspend fun copyAssets() = withContext(Dispatchers.IO) {
         modelsPath.mkdirs()
         samplesPath.mkdirs()
-        application.copyData("models", modelsPath, ::printMessage)
+        //application.copyData("models", modelsPath, ::printMessage)
         application.copyData("samples", samplesPath, ::printMessage)
         printMessage("All data copied to working directory.\n")
     }
 
     private suspend fun loadBaseModel() = withContext(Dispatchers.IO) {
         printMessage("Loading model...\n")
-        val firstModel = modelsPath.listFiles()!!.first()
-        whisperContext = WhisperContext.createContext(firstModel.absolutePath)
-        printMessage("Loaded model ${firstModel.name}.\n")
+        val models = application.assets.list("models/")
+        if (models != null) {
+            val inputstream = application.assets.open("models/" + models[0])
+            whisperContext = WhisperContext.createContextFromInputStream(inputstream)
+            printMessage("Loaded model ${models[0]}.\n")
+        }
+
+        //val firstModel = modelsPath.listFiles()!!.first()
+        //whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath)
     }
 
     fun transcribeSample() = viewModelScope.launch {
index a6dfdcca338e21e4fd0fca8c058ae7c82da2101a..edd041a78302c144ff5943c89ce0b5bb15eb2b26 100644 (file)
@@ -4,6 +4,7 @@ import android.os.Build
 import android.util.Log
 import kotlinx.coroutines.*
 import java.io.File
+import java.io.InputStream
 import java.util.concurrent.Executors
 
 private const val LOG_TAG = "LibWhisper"
@@ -39,13 +40,22 @@ class WhisperContext private constructor(private var ptr: Long) {
     }
 
     companion object {
-        fun createContext(filePath: String): WhisperContext {
+        fun createContextFromFile(filePath: String): WhisperContext {
             val ptr = WhisperLib.initContext(filePath)
             if (ptr == 0L) {
                 throw java.lang.RuntimeException("Couldn't create context with path $filePath")
             }
             return WhisperContext(ptr)
         }
+
+        fun createContextFromInputStream(stream: InputStream): WhisperContext {
+            val ptr = WhisperLib.initContextFromInputStream(stream)
+
+            if (ptr == 0L) {
+                throw java.lang.RuntimeException("Couldn't create context from input stream")
+            }
+            return WhisperContext(ptr)
+        }
     }
 }
 
@@ -76,6 +86,7 @@ private class WhisperLib {
         }
 
         // JNI methods
+        external fun initContextFromInputStream(inputStream: InputStream): Long
         external fun initContext(modelPath: String): Long
         external fun freeContext(contextPtr: Long)
         external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
index e3fe69572c43ace1d50db041785b339dcb96efd4..0fd2897c73e0ba785ec0d154a4ed30f95d1f23da 100644 (file)
@@ -2,6 +2,7 @@
 #include <android/log.h>
 #include <stdlib.h>
 #include <sys/sysinfo.h>
+#include <string.h>
 #include "whisper.h"
 
 #define UNUSED(x) (void)(x)
@@ -17,13 +18,86 @@ static inline int max(int a, int b) {
     return (a > b) ? a : b;
 }
 
+struct input_stream_context {
+    size_t offset;
+    JNIEnv * env;
+    jobject thiz;
+    jobject input_stream;
+
+    jmethodID mid_available;
+    jmethodID mid_read;
+};
+
+size_t inputStreamRead(void * ctx, void * output, size_t read_size) {
+    struct input_stream_context* is = (struct input_stream_context*)ctx;
+
+    jint avail_size = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
+    jint size_to_copy = read_size < avail_size ? (jint)read_size : avail_size;
+
+    jbyteArray byte_array = (*is->env)->NewByteArray(is->env, size_to_copy);
+
+    jint n_read = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_read, byte_array, 0, size_to_copy);
+
+    if (size_to_copy != read_size || size_to_copy != n_read) {
+        LOGI("Insufficient Read: Req=%zu, ToCopy=%d, Available=%d", read_size, size_to_copy, n_read);
+    }
+
+    jbyte* byte_array_elements = (*is->env)->GetByteArrayElements(is->env, byte_array, NULL);
+    memcpy(output, byte_array_elements, size_to_copy);
+    (*is->env)->ReleaseByteArrayElements(is->env, byte_array, byte_array_elements, JNI_ABORT);
+
+    (*is->env)->DeleteLocalRef(is->env, byte_array);
+
+    is->offset += size_to_copy;
+
+    return size_to_copy;
+}
+bool inputStreamEof(void * ctx) {
+    struct input_stream_context* is = (struct input_stream_context*)ctx;
+
+    jint result = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
+    return result <= 0;
+}
+void inputStreamClose(void * ctx) {
+
+}
+
+JNIEXPORT jlong JNICALL
+Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContextFromInputStream(
+        JNIEnv *env, jobject thiz, jobject input_stream) {
+    UNUSED(thiz);
+
+    struct whisper_context *context = NULL;
+    struct whisper_model_loader loader = {};
+    struct input_stream_context inp_ctx = {};
+
+    inp_ctx.offset = 0;
+    inp_ctx.env = env;
+    inp_ctx.thiz = thiz;
+    inp_ctx.input_stream = input_stream;
+
+    jclass cls = (*env)->GetObjectClass(env, input_stream);
+    inp_ctx.mid_available = (*env)->GetMethodID(env, cls, "available", "()I");
+    inp_ctx.mid_read = (*env)->GetMethodID(env, cls, "read", "([BII)I");
+
+    loader.context = &inp_ctx;
+    loader.read = inputStreamRead;
+    loader.eof = inputStreamEof;
+    loader.close = inputStreamClose;
+
+    loader.eof(loader.context);
+
+    context = whisper_init(&loader);
+    return (jlong) context;
+}
+
 JNIEXPORT jlong JNICALL
 Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContext(
         JNIEnv *env, jobject thiz, jstring model_path_str) {
     UNUSED(thiz);
     struct whisper_context *context = NULL;
     const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL);
-    context = whisper_init(model_path_chars);
+    context = whisper_init_from_file(model_path_chars);
     (*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars);
     return (jlong) context;
 }
diff --git a/examples/whisper.android/local.properties b/examples/whisper.android/local.properties
deleted file mode 100644 (file)
index cd5e215..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-## This file is automatically generated by Android Studio.
-# Do not modify this file -- YOUR CHANGES WILL BE ERASED!
-#
-# This file should *NOT* be checked into Version Control Systems,
-# as it contains information specific to your local configuration.
-#
-# Location of the SDK. This is only used by Gradle.
-# For customization when using a Version Control System, please read the
-# header note.
-sdk.dir=/Users/kevin/Library/Android/sdk
\ No newline at end of file
index d6aef36924527f64101f9eda89f3e8fa7ab65755..8a1e876c395c6451126851d40166682abba6f9ce 100644 (file)
@@ -61,7 +61,7 @@ void AudioInputCallback(void * inUserData,
         NSLog(@"Loading model from %@", modelPath);
 
         // create ggml context
-        stateInp.ctx = whisper_init([modelPath UTF8String]);
+        stateInp.ctx = whisper_init_from_file([modelPath UTF8String]);
 
         // check if the model was loaded successfully
         if (stateInp.ctx == NULL) {
index 9adfb425e09959c8332af3b2e91a7dd1a90147cf..e9645b34f7466c3c0275e0e5f24a9ffdd22ad004 100644 (file)
@@ -55,7 +55,7 @@ actor WhisperContext {
     }
     
     static func createContext(path: String) throws -> WhisperContext {
-        let context = whisper_init(path)
+        let context = whisper_init_from_file(path)
         if let context {
             return WhisperContext(context: context)
         } else {
index 33ae0a1b73bb8ffa2f5617b3f77a04213514b5ce..f92d814e0fcfd9c5f9b5480fc59c5ddc6d7d1d1e 100644 (file)
@@ -18,7 +18,7 @@ EMSCRIPTEN_BINDINGS(whisper) {
 
         for (size_t i = 0; i < g_contexts.size(); ++i) {
             if (g_contexts[i] == nullptr) {
-                g_contexts[i] = whisper_init(path_model.c_str());
+                g_contexts[i] = whisper_init_from_file(path_model.c_str());
                 if (g_contexts[i] != nullptr) {
                     return i + 1;
                 } else {
index e8d9f0c925a34c0738693542d1d04a0fa897e4df..433b735502918e1e77f3f13661a1111e93d7ade7 100644 (file)
@@ -437,8 +437,8 @@ struct whisper_context {
 };
 
 template<typename T>
-static void read_safe(std::ifstream& fin, T& dest) {
-    fin.read((char*)& dest, sizeof(T));
+static void read_safe(whisper_model_loader * loader, T & dest) {
+    loader->read(loader->context, &dest, sizeof(T));
 }
 
 // load the model from a ggml file
@@ -452,24 +452,18 @@ static void read_safe(std::ifstream& fin, T& dest) {
 //
 // see the convert-pt-to-ggml.py script for details
 //
-static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
-    fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
+static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
+    fprintf(stderr, "%s: loading model\n", __func__);
 
     auto & model = wctx.model;
     auto & vocab = wctx.vocab;
 
-    auto fin = std::ifstream(fname, std::ios::binary);
-    if (!fin) {
-        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
-        return false;
-    }
-
     // verify magic
     {
         uint32_t magic;
-        read_safe(fin, magic);
+        read_safe(loader, magic);
         if (magic != 0x67676d6c) {
-            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__);
             return false;
         }
     }
@@ -478,17 +472,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     {
         auto & hparams = model.hparams;
 
-        read_safe(fin, hparams.n_vocab);
-        read_safe(fin, hparams.n_audio_ctx);
-        read_safe(fin, hparams.n_audio_state);
-        read_safe(fin, hparams.n_audio_head);
-        read_safe(fin, hparams.n_audio_layer);
-        read_safe(fin, hparams.n_text_ctx);
-        read_safe(fin, hparams.n_text_state);
-        read_safe(fin, hparams.n_text_head);
-        read_safe(fin, hparams.n_text_layer);
-        read_safe(fin, hparams.n_mels);
-        read_safe(fin, hparams.f16);
+        read_safe(loader, hparams.n_vocab);
+        read_safe(loader, hparams.n_audio_ctx);
+        read_safe(loader, hparams.n_audio_state);
+        read_safe(loader, hparams.n_audio_head);
+        read_safe(loader, hparams.n_audio_layer);
+        read_safe(loader, hparams.n_text_ctx);
+        read_safe(loader, hparams.n_text_state);
+        read_safe(loader, hparams.n_text_head);
+        read_safe(loader, hparams.n_text_layer);
+        read_safe(loader, hparams.n_mels);
+        read_safe(loader, hparams.f16);
 
         assert(hparams.n_text_state == hparams.n_audio_state);
 
@@ -536,17 +530,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     {
         auto & filters = wctx.model.filters;
 
-        read_safe(fin, filters.n_mel);
-        read_safe(fin, filters.n_fft);
+        read_safe(loader, filters.n_mel);
+        read_safe(loader, filters.n_fft);
 
         filters.data.resize(filters.n_mel * filters.n_fft);
-        fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
+        loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
     }
 
     // load vocab
     {
         int32_t n_vocab = 0;
-        read_safe(fin, n_vocab);
+        read_safe(loader, n_vocab);
 
         //if (n_vocab != model.hparams.n_vocab) {
         //    fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
@@ -561,11 +555,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         for (int i = 0; i < n_vocab; i++) {
             uint32_t len;
-            read_safe(fin, len);
+            read_safe(loader, len);
 
             if (len > 0) {
                 tmp.resize(len);
-                fin.read(&tmp[0], tmp.size()); // read to buffer
+                loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
                 word.assign(&tmp[0], tmp.size());
             } else {
                 // seems like we have an empty-string token in multi-language models (i = 50256)
@@ -1017,24 +1011,24 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             int32_t length;
             int32_t ftype;
 
-            read_safe(fin, n_dims);
-            read_safe(fin, length);
-            read_safe(fin, ftype);
+            read_safe(loader, n_dims);
+            read_safe(loader, length);
+            read_safe(loader, ftype);
 
-            if (fin.eof()) {
+            if (loader->eof(loader->context)) {
                 break;
             }
 
             int32_t nelements = 1;
             int32_t ne[3] = { 1, 1, 1 };
             for (int i = 0; i < n_dims; ++i) {
-                read_safe(fin, ne[i]);
+                read_safe(loader, ne[i]);
                 nelements *= ne[i];
             }
 
             std::string name;
             std::vector<char> tmp(length); // create a buffer
-            fin.read(&tmp[0], tmp.size()); // read to buffer
+            loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
             name.assign(&tmp[0], tmp.size());
 
             if (model.tensors.find(name) == model.tensors.end()) {
@@ -1062,7 +1056,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
                 return false;
             }
 
-            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+            loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
 
             //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
             total_size += ggml_nbytes(tensor);
@@ -1079,8 +1073,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         }
     }
 
-    fin.close();
-
     return true;
 }
 
@@ -2240,7 +2232,74 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
 // interface implementation
 //
 
-struct whisper_context * whisper_init(const char * path_model) {
+struct whisper_context * whisper_init_from_file(const char * path_model) {
+    whisper_model_loader loader = {};
+
+    fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
+
+    auto fin = std::ifstream(path_model, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model);
+        return nullptr;
+    }
+
+    loader.context = &fin;
+    loader.read = [](void * ctx, void * output, size_t read_size) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->read((char *)output, read_size);
+        return read_size;
+    };
+
+    loader.eof = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        return fin->eof();
+    };
+
+    loader.close = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->close();
+    };
+
+    return whisper_init(&loader);
+}
+
+struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
+    struct buf_context {
+        uint8_t* buffer;
+        size_t size;
+        size_t current_offset;
+    };
+
+    buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
+    whisper_model_loader loader = {};
+
+    fprintf(stderr, "%s: loading model from buffer\n", __func__);
+
+    loader.context = &ctx;
+
+    loader.read = [](void * ctx, void * output, size_t read_size) {
+        buf_context * buf = reinterpret_cast<buf_context *>(ctx);
+
+        size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset;
+
+        memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
+        buf->current_offset += size_to_copy;
+
+        return size_to_copy;
+    };
+
+    loader.eof = [](void * ctx) {
+        buf_context * buf = reinterpret_cast<buf_context *>(ctx);
+
+        return buf->current_offset >= buf->size;
+    };
+
+    loader.close = [](void * /*ctx*/) { };
+
+    return whisper_init(&loader);
+}
+
+struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
     ggml_time_init();
 
     whisper_context * ctx = new whisper_context;
@@ -2249,14 +2308,17 @@ struct whisper_context * whisper_init(const char * path_model) {
 
     ctx->t_start_us = t_start_us;
 
-    if (!whisper_model_load(path_model, *ctx)) {
-        fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
+    if (!whisper_model_load(loader, *ctx)) {
+        loader->close(loader->context);
+        fprintf(stderr, "%s: failed to load model\n", __func__);
         delete ctx;
         return nullptr;
     }
 
     ctx->t_load_us = ggml_time_us() - t_start_us;
 
+    loader->close(loader->context);
+
     return ctx;
 }
 
index 8cb16caaeea9cc19b5efd7306b9a26cbc574493f..582138f9ddb7fbfe95fdb57fcb3a11d601007731 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -1,6 +1,7 @@
 #ifndef WHISPER_H
 #define WHISPER_H
 
+#include <stddef.h>
 #include <stdint.h>
 #include <stdbool.h>
 
@@ -40,7 +41,7 @@ extern "C" {
     //
     //     ...
     //
-    //     struct whisper_context * ctx = whisper_init("/path/to/ggml-base.en.bin");
+    //     struct whisper_context * ctx = whisper_init_from_file("/path/to/ggml-base.en.bin");
     //
     //     if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
     //         fprintf(stderr, "failed to process audio\n");
@@ -84,9 +85,20 @@ extern "C" {
         float vlen;        // voice length of the token
     } whisper_token_data;
 
-    // Allocates all memory needed for the model and loads the model from the given file.
-    // Returns NULL on failure.
-    WHISPER_API struct whisper_context * whisper_init(const char * path_model);
+    typedef struct whisper_model_loader {
+        void * context;
+
+        size_t (*read)(void * ctx, void * output, size_t read_size);
+        bool    (*eof)(void * ctx);
+        void  (*close)(void * ctx);
+    } whisper_model_loader;
+
+    // Various function to load a ggml whisper model.
+    // Allocates (almost) all memory needed for the model.
+    // Return NULL on failure
+    WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
+    WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
 
     // Frees all memory allocated by the model.
     WHISPER_API void whisper_free(struct whisper_context * ctx);