]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync latest whisper.cpp
authorGeorgi Gerganov <redacted>
Sun, 8 Jan 2023 18:23:01 +0000 (20:23 +0200)
committerGeorgi Gerganov <redacted>
Sun, 8 Jan 2023 18:23:24 +0000 (20:23 +0200)
examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h
src/ggml.c

index d387a77c8153fe9943b89b318166ec1b1e0e5b79..48e02923d017fd01c326046a2f83f3ac98ef7fd0 100644 (file)
@@ -478,7 +478,7 @@ int main(int argc, char ** argv) {
 
     // whisper init
 
-    struct whisper_context * ctx = whisper_init(params.model.c_str());
+    struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
 
     if (ctx == nullptr) {
         fprintf(stderr, "error: failed to initialize whisper context\n");
index e8d9f0c925a34c0738693542d1d04a0fa897e4df..a64505693f718ec55c5a3c438ec619d271865a20 100644 (file)
@@ -437,8 +437,8 @@ struct whisper_context {
 };
 
 template<typename T>
-static void read_safe(std::ifstream& fin, T& dest) {
-    fin.read((char*)& dest, sizeof(T));
+static void read_safe(whisper_model_loader * loader, T & dest) {
+    loader->read(loader->context, &dest, sizeof(T));
 }
 
 // load the model from a ggml file
@@ -452,24 +452,18 @@ static void read_safe(std::ifstream& fin, T& dest) {
 //
 // see the convert-pt-to-ggml.py script for details
 //
-static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
-    fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
+static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
+    fprintf(stderr, "%s: loading model\n", __func__);
 
     auto & model = wctx.model;
     auto & vocab = wctx.vocab;
 
-    auto fin = std::ifstream(fname, std::ios::binary);
-    if (!fin) {
-        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
-        return false;
-    }
-
     // verify magic
     {
         uint32_t magic;
-        read_safe(fin, magic);
+        read_safe(loader, magic);
         if (magic != 0x67676d6c) {
-            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__);
             return false;
         }
     }
@@ -478,17 +472,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     {
         auto & hparams = model.hparams;
 
-        read_safe(fin, hparams.n_vocab);
-        read_safe(fin, hparams.n_audio_ctx);
-        read_safe(fin, hparams.n_audio_state);
-        read_safe(fin, hparams.n_audio_head);
-        read_safe(fin, hparams.n_audio_layer);
-        read_safe(fin, hparams.n_text_ctx);
-        read_safe(fin, hparams.n_text_state);
-        read_safe(fin, hparams.n_text_head);
-        read_safe(fin, hparams.n_text_layer);
-        read_safe(fin, hparams.n_mels);
-        read_safe(fin, hparams.f16);
+        read_safe(loader, hparams.n_vocab);
+        read_safe(loader, hparams.n_audio_ctx);
+        read_safe(loader, hparams.n_audio_state);
+        read_safe(loader, hparams.n_audio_head);
+        read_safe(loader, hparams.n_audio_layer);
+        read_safe(loader, hparams.n_text_ctx);
+        read_safe(loader, hparams.n_text_state);
+        read_safe(loader, hparams.n_text_head);
+        read_safe(loader, hparams.n_text_layer);
+        read_safe(loader, hparams.n_mels);
+        read_safe(loader, hparams.f16);
 
         assert(hparams.n_text_state == hparams.n_audio_state);
 
@@ -536,17 +530,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     {
         auto & filters = wctx.model.filters;
 
-        read_safe(fin, filters.n_mel);
-        read_safe(fin, filters.n_fft);
+        read_safe(loader, filters.n_mel);
+        read_safe(loader, filters.n_fft);
 
         filters.data.resize(filters.n_mel * filters.n_fft);
-        fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
+        loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
     }
 
     // load vocab
     {
         int32_t n_vocab = 0;
-        read_safe(fin, n_vocab);
+        read_safe(loader, n_vocab);
 
         //if (n_vocab != model.hparams.n_vocab) {
         //    fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
@@ -561,11 +555,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         for (int i = 0; i < n_vocab; i++) {
             uint32_t len;
-            read_safe(fin, len);
+            read_safe(loader, len);
 
             if (len > 0) {
                 tmp.resize(len);
-                fin.read(&tmp[0], tmp.size()); // read to buffer
+                loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
                 word.assign(&tmp[0], tmp.size());
             } else {
                 // seems like we have an empty-string token in multi-language models (i = 50256)
@@ -1017,24 +1011,24 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             int32_t length;
             int32_t ftype;
 
-            read_safe(fin, n_dims);
-            read_safe(fin, length);
-            read_safe(fin, ftype);
+            read_safe(loader, n_dims);
+            read_safe(loader, length);
+            read_safe(loader, ftype);
 
-            if (fin.eof()) {
+            if (loader->eof(loader->context)) {
                 break;
             }
 
             int32_t nelements = 1;
             int32_t ne[3] = { 1, 1, 1 };
             for (int i = 0; i < n_dims; ++i) {
-                read_safe(fin, ne[i]);
+                read_safe(loader, ne[i]);
                 nelements *= ne[i];
             }
 
             std::string name;
             std::vector<char> tmp(length); // create a buffer
-            fin.read(&tmp[0], tmp.size()); // read to buffer
+            loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
             name.assign(&tmp[0], tmp.size());
 
             if (model.tensors.find(name) == model.tensors.end()) {
@@ -1062,7 +1056,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
                 return false;
             }
 
-            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+            loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
 
             //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
             total_size += ggml_nbytes(tensor);
@@ -1079,8 +1073,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         }
     }
 
-    fin.close();
-
     return true;
 }
 
@@ -1479,6 +1471,7 @@ static bool whisper_encode(
         }
 
         ggml_graph_compute(ctx0, &gf);
+        //ggml_graph_print(&gf);
     }
 
     ////////////////////////////////////////////////////////////////////////////
@@ -2240,7 +2233,74 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
 // interface implementation
 //
 
-struct whisper_context * whisper_init(const char * path_model) {
+struct whisper_context * whisper_init_from_file(const char * path_model) {
+    whisper_model_loader loader = {};
+
+    fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
+
+    auto fin = std::ifstream(path_model, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_model);
+        return nullptr;
+    }
+
+    loader.context = &fin;
+    loader.read = [](void * ctx, void * output, size_t read_size) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->read((char *)output, read_size);
+        return read_size;
+    };
+
+    loader.eof = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        return fin->eof();
+    };
+
+    loader.close = [](void * ctx) {
+        std::ifstream * fin = (std::ifstream*)ctx;
+        fin->close();
+    };
+
+    return whisper_init(&loader);
+}
+
+struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
+    struct buf_context {
+        uint8_t* buffer;
+        size_t size;
+        size_t current_offset;
+    };
+
+    buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
+    whisper_model_loader loader = {};
+
+    fprintf(stderr, "%s: loading model from buffer\n", __func__);
+
+    loader.context = &ctx;
+
+    loader.read = [](void * ctx, void * output, size_t read_size) {
+        buf_context * buf = reinterpret_cast<buf_context *>(ctx);
+
+        size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset;
+
+        memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
+        buf->current_offset += size_to_copy;
+
+        return size_to_copy;
+    };
+
+    loader.eof = [](void * ctx) {
+        buf_context * buf = reinterpret_cast<buf_context *>(ctx);
+
+        return buf->current_offset >= buf->size;
+    };
+
+    loader.close = [](void * /*ctx*/) { };
+
+    return whisper_init(&loader);
+}
+
+struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
     ggml_time_init();
 
     whisper_context * ctx = new whisper_context;
@@ -2249,14 +2309,17 @@ struct whisper_context * whisper_init(const char * path_model) {
 
     ctx->t_start_us = t_start_us;
 
-    if (!whisper_model_load(path_model, *ctx)) {
-        fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
+    if (!whisper_model_load(loader, *ctx)) {
+        loader->close(loader->context);
+        fprintf(stderr, "%s: failed to load model\n", __func__);
         delete ctx;
         return nullptr;
     }
 
     ctx->t_load_us = ggml_time_us() - t_start_us;
 
+    loader->close(loader->context);
+
     return ctx;
 }
 
@@ -3326,7 +3389,7 @@ static int timestamp_to_sample(int64_t t, int n_samples) {
 }
 
 static int64_t sample_to_timestamp(int i_sample) {
-    return (100*i_sample)/WHISPER_SAMPLE_RATE;
+    return (100ll*i_sample)/WHISPER_SAMPLE_RATE;
 }
 
 // a cost-function / heuristic that is high for text that takes longer to pronounce
index 8cb16caaeea9cc19b5efd7306b9a26cbc574493f..63f61af5114f2cb8864219b9b9615bea64126c2d 100644 (file)
@@ -1,6 +1,7 @@
 #ifndef WHISPER_H
 #define WHISPER_H
 
+#include <stddef.h>
 #include <stdint.h>
 #include <stdbool.h>
 
@@ -40,7 +41,7 @@ extern "C" {
     //
     //     ...
     //
-    //     struct whisper_context * ctx = whisper_init("/path/to/ggml-base.en.bin");
+    //     struct whisper_context * ctx = whisper_init_from_file("/path/to/ggml-base.en.bin");
     //
     //     if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
     //         fprintf(stderr, "failed to process audio\n");
@@ -84,9 +85,20 @@ extern "C" {
         float vlen;        // voice length of the token
     } whisper_token_data;
 
-    // Allocates all memory needed for the model and loads the model from the given file.
-    // Returns NULL on failure.
-    WHISPER_API struct whisper_context * whisper_init(const char * path_model);
+    typedef struct whisper_model_loader {
+        void * context;
+
+        size_t (*read)(void * ctx, void * output, size_t read_size);
+        bool    (*eof)(void * ctx);
+        void  (*close)(void * ctx);
+    } whisper_model_loader;
+
+    // Various functions for loading a ggml whisper model.
+    // Allocate (almost) all memory needed for the model.
+    // Return NULL on failure
+    WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size);
+    WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader);
 
     // Frees all memory allocated by the model.
     WHISPER_API void whisper_free(struct whisper_context * ctx);
index eefdcdd7f65a93ec6d209c9e9cfa379f1a1a45c2..c59ee64af00c48455e49510ca1760d17e2b9169d 100644 (file)
@@ -84,7 +84,7 @@ typedef void* thread_ret_t;
 #define GGML_GELU_FP16
 
 #define GGML_SOFT_MAX_UNROLL 4
-#define GGML_VEC_DOT_UNROLL  4
+#define GGML_VEC_DOT_UNROLL  2
 
 #ifdef GGML_USE_ACCELERATE
 // uncomment to use vDSP for soft max computation
@@ -923,9 +923,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
     ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
 
-    const ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL] = { xv };
+    ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
 
-    for (int i = 1; i < GGML_VEC_DOT_UNROLL; ++i) {
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
         x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
     }
 
@@ -1109,8 +1109,8 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
     ggml_float sum = 0.0;
     for (int i = 0; i < n; ++i) {
         sum += x[i];
-        *s += sum;
     }
+    *s = sum;
 #else
     vDSP_sve(x, 1, s, n);
 #endif
@@ -3724,8 +3724,6 @@ static void ggml_compute_forward_sum_f32(
     assert(ggml_is_scalar(dst));
     assert(src0->nb[0] == sizeof(float));
 
-    *(float *) (dst->data) = 0.0f;
-
     const int ne00 = src0->ne[0];
     const int ne01 = src0->ne[1];
     const int ne02 = src0->ne[2];
@@ -3811,8 +3809,6 @@ static void ggml_compute_forward_mean_f32(
     for (int i03 = 0; i03 < ne03; i03++) {
         for (int i02 = 0; i02 < ne02; i02++) {
             for (int i01 = 0; i01 < ne01; i01++) {
-                *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) = 0.0f;
-
                 ggml_vec_sum_f32(ne00,
                         (float *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                         (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
@@ -4791,7 +4787,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             }
         }
     } else {
-        // parallelize by src1 columns using ggml_vec_mad_f32
+        // parallelize by src1 columns using ggml_vec_mad_f16
         // each thread has its own work data
         // during FINALIZE we accumulate all work data into dst
 
@@ -6158,40 +6154,37 @@ static void ggml_compute_forward_flash_attn_f16(
             S[i] = -INFINITY;
         }
 
-        // looks like unrolling here does not help
-#if 1
-        for (int ic = 0; ic < nek1; ++ic) {
-            // k indices
-            const int ik3 = iq3;
-            const int ik2 = iq2;
-            const int ik1 = ic;
+        if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
+            for (int ic = 0; ic < nek1; ++ic) {
+                // k indices
+                const int ik3 = iq3;
+                const int ik2 = iq2;
+                const int ik1 = ic;
 
-            // S indices
-            const int i1 = ik1;
-
-            ggml_vec_dot_f16(neq0,
-                    S + i1,
-                    (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                    (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
-        }
-#else
-        GGML_ASSERT(nek1 % GGML_VEC_DOT_UNROLL == 0);
-
-        for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
-            // k indices
-            const int ik3 = iq3;
-            const int ik2 = iq2;
-            const int ik1 = ic;
+                // S indices
+                const int i1 = ik1;
 
-            // S indices
-            const int i1 = ik1;
-
-            ggml_vec_dot_f16_unroll(neq0, nbk1,
-                    S + i1,
-                                    ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                    (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+                ggml_vec_dot_f16(neq0,
+                        S + i1,
+                        (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+            }
+        } else {
+            for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
+                // k indices
+                const int ik3 = iq3;
+                const int ik2 = iq2;
+                const int ik1 = ic;
+
+                // S indices
+                const int i1 = ik1;
+
+                ggml_vec_dot_f16_unroll(neq0, nbk1,
+                        S + i1,
+                        ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+            }
         }
-#endif
 
         // scale
         ggml_vec_scale_f32(nek1, S, scale);
@@ -6261,18 +6254,30 @@ static void ggml_compute_forward_flash_attn_f16(
             S16[i] = GGML_FP32_TO_FP16(S[i]);
         }
 
-        GGML_ASSERT(nev1 % GGML_VEC_DOT_UNROLL == 0);
+        if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
+            for (int ic = 0; ic < nev1; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
 
-        for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
-            // dst indices
-            const int i1 = iq1;
-            const int i2 = iq2;
-            const int i3 = iq3;
+                ggml_vec_dot_f16(nek1,
+                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
+                        (ggml_fp16_t *) ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                        S16);
+            }
+        } else {
+            for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
 
-            ggml_vec_dot_f16_unroll(nek1, nbv1,
-                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
-                              ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
-                    S16);
+                ggml_vec_dot_f16_unroll(nek1, nbv1,
+                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
+                        ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                        S16);
+            }
         }
     }
 }