stream : partial encoder experiments
author    Georgi Gerganov <redacted>
          Fri, 11 Nov 2022 20:33:10 +0000 (22:33 +0200)
committer Georgi Gerganov <redacted>
          Sun, 20 Nov 2022 19:22:41 +0000 (21:22 +0200)
examples/stream/stream.cpp
whisper.cpp
whisper.h

index 718c8151d39d4e413e454dd8b6cfb42765591932..3c2f86126d992a8a1a550a5b5d6d59ad6e2caa8f 100644 (file)
--- a/examples/stream/stream.cpp
+++ b/examples/stream/stream.cpp
@@ -221,6 +221,7 @@ int main(int argc, char ** argv) {
     const int n_samples = (params.step_ms/1000.0)*WHISPER_SAMPLE_RATE;
     const int n_samples_len = (params.length_ms/1000.0)*WHISPER_SAMPLE_RATE;
     const int n_samples_30s = 30*WHISPER_SAMPLE_RATE;
+    const int n_samples_keep = 0.2*WHISPER_SAMPLE_RATE;
 
     std::vector<float> pcmf32(n_samples_30s, 0.0f);
     std::vector<float> pcmf32_old;
@@ -303,7 +304,7 @@ int main(int argc, char ** argv) {
         //const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_30s/30 - n_samples_new));
 
         // take up to params.length_ms audio from previous iteration
-        const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_len - n_samples_new));
+        const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_keep + n_samples_len - n_samples_new));
 
         //printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
 
@@ -379,7 +380,8 @@ int main(int argc, char ** argv) {
             if ((n_iter % n_new_line) == 0) {
                 printf("\n");
 
-                pcmf32_old.clear();
+                // keep part of the audio for next iteration to try to mitigate word boundary issues
+                pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
             }
         }
     }
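
Note: the stream.cpp change above stops clearing the audio buffer on every "new line" event. The last n_samples_keep (0.2 s) of samples are carried over, and the next window takes up to keep + length worth of old audio, so words cut at a window boundary get a second chance to be decoded. Below is a minimal, standalone sketch of that overlap logic with illustrative buffer sizes; only WHISPER_SAMPLE_RATE = 16000 is taken from whisper.h, everything else is made up for the example.

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        const int sample_rate    = 16000;              // WHISPER_SAMPLE_RATE
        const int n_samples_keep = 0.2 * sample_rate;  // 200 ms carried over between windows
        const int n_samples_len  = 5   * sample_rate;  // stand-in for params.length_ms worth of audio

        std::vector<float> pcmf32_old(3 * sample_rate, 0.0f); // audio left over from the previous step
        std::vector<float> pcmf32_new(1 * sample_rate, 0.0f); // freshly captured audio

        // take up to (keep + length) worth of old audio, as in the patched stream.cpp
        const int n_samples_new  = (int) pcmf32_new.size();
        const int n_samples_take = std::min((int) pcmf32_old.size(),
                                            std::max(0, n_samples_keep + n_samples_len - n_samples_new));

        std::vector<float> pcmf32(n_samples_take + n_samples_new);
        std::copy(pcmf32_old.end() - n_samples_take, pcmf32_old.end(), pcmf32.begin());
        std::copy(pcmf32_new.begin(), pcmf32_new.end(), pcmf32.begin() + n_samples_take);

        // on a "new line" boundary, keep only the tail instead of clearing everything
        pcmf32_old.assign(pcmf32.end() - n_samples_keep, pcmf32.end());

        printf("take = %d, new = %d, kept for next window = %d\n",
               n_samples_take, n_samples_new, (int) pcmf32_old.size());
        return 0;
    }

With the sizes above this prints take = 48000, new = 16000, kept for next window = 3200, i.e. 0.2 s of audio is always re-fed to the next iteration.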
index a8b9e7149cd81e6478e3151a3e1b394f31e277d9..7c4a1d4c4b65dafb14728940faad9d7739c414a7 100644 (file)
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -613,7 +613,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         const int n_audio_state = hparams.n_audio_state;
         const int n_audio_layer = hparams.n_audio_layer;
 
-        const int n_text_ctx = hparams.n_text_ctx;
+        const int n_text_ctx   = hparams.n_text_ctx;
         const int n_text_state = hparams.n_text_state;
         const int n_text_layer = hparams.n_text_layer;
 
@@ -748,7 +748,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         const int n_audio_state = hparams.n_audio_state;
         const int n_audio_layer = hparams.n_audio_layer;
 
-        const int n_text_ctx = hparams.n_text_ctx;
+        const int n_text_ctx   = hparams.n_text_ctx;
         const int n_text_state = hparams.n_text_state;
         const int n_text_layer = hparams.n_text_layer;
 
@@ -967,13 +967,16 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
 
         // key/value memory for the cross-attention layer
         {
-            const int n_audio_ctx   = hparams.n_audio_ctx;
+            const int n_audio_ctx = hparams.n_audio_ctx;
 
             const int n_mem      = n_text_layer*n_audio_ctx;
             const int n_elements = n_text_state*n_mem;
 
             model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
             model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+
+            //memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
+            //memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
         }
 
         const size_t memory_size =
@@ -1076,13 +1079,11 @@ static bool whisper_encode(
     const auto & mel_inp = wctx.mel;
     const auto & hparams = model.hparams;
 
-    const int n_ctx   = hparams.n_audio_ctx;
+    const int n_ctx   = WHISPER_EXPERIMENT_AUDIO_CTX;
     const int n_state = hparams.n_audio_state;
     const int n_head  = hparams.n_audio_head;
     const int n_layer = hparams.n_audio_layer;
 
-    const int N = n_ctx;
-
     const int n_mels = hparams.n_mels;
     assert(mel_inp.n_mel == n_mels);
 
@@ -1132,7 +1133,24 @@ static bool whisper_encode(
         cur = ggml_gelu(ctx0, cur);
     }
 
-    cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
+    //static int iter = -1;
+    //const int n_iter = 1500/n_ctx;
+
+    //iter = (iter + 1) % n_iter;
+
+    //if (iter == 0) {
+    //    memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
+    //    memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
+    //}
+
+    static int iter = 0;
+
+    const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
+    const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
+
+    struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
+
+    cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur));
 
     struct ggml_tensor * inpL = cur;
 
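Note: because the encoder above now processes only n_ctx = WHISPER_EXPERIMENT_AUDIO_CTX positions per pass, the positional embedding is no longer added whole; instead a 2-D view into model.e_pe is taken starting at row n_ctx*iter (iter stays 0 in this commit; the commented-out code sketches cycling it per chunk). Below is a minimal sketch of the stride/offset arithmetic only, with illustrative sizes: n_state = 512 as in the base models, and sizeof(float) standing in for ggml_element_size(model.e_pe).

    #include <cstddef>
    #include <cstdio>

    int main() {
        const int    n_audio_ctx = 1500;          // rows of model.e_pe (full 30 s context)
        const int    n_state     = 512;           // columns of model.e_pe (base model size, illustrative)
        const size_t elem_size   = sizeof(float); // stands in for ggml_element_size(model.e_pe)

        const int n_ctx = 512;                    // WHISPER_EXPERIMENT_AUDIO_CTX
        const int iter  = 1;                      // nonzero only to illustrate the offset; the commit leaves it at 0

        // same arithmetic as e_pe_stride / e_pe_offset in whisper_encode()
        const size_t e_pe_stride = n_state * elem_size;
        const size_t e_pe_offset = n_state * elem_size * n_ctx * iter;

        // the view covers rows [n_ctx*iter, n_ctx*(iter+1)) of the embedding table
        const int row_first = (int) (e_pe_offset / e_pe_stride);
        const int row_last  = row_first + n_ctx - 1;

        printf("row stride = %zu bytes, view = rows %d..%d of %d\n",
               e_pe_stride, row_first, row_last, n_audio_ctx);
        return 0;
    }
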
@@ -1198,14 +1216,14 @@ static bool whisper_encode(
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Qcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Kcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * V =
@@ -1213,9 +1231,9 @@ static bool whisper_encode(
                         ggml_permute(ctxL,
                             ggml_reshape_3d(ctxL,
                                 Vcur,
-                                n_state/n_head, n_head, N),
+                                n_state/n_head, n_head, n_ctx),
                             1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
+                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
                         );
 
             struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
@@ -1224,14 +1242,14 @@ static bool whisper_encode(
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Qcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Kcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
                         0, 2, 1, 3);
 
             // K * Q
@@ -1249,7 +1267,7 @@ static bool whisper_encode(
             //    ggml_permute(ctxL,
             //            ggml_cpy(ctxL,
             //                Vcur,
-            //                ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+            //                ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
             //            1, 2, 0, 3);
 
             //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@@ -1259,9 +1277,9 @@ static bool whisper_encode(
                         ggml_permute(ctxL,
                             ggml_reshape_3d(ctxL,
                                 Vcur,
-                                n_state/n_head, n_head, N),
+                                n_state/n_head, n_head, n_ctx),
                             0, 2, 1, 3),
-                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head)
+                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
                         );
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
@@ -1271,7 +1289,7 @@ static bool whisper_encode(
 
             cur = ggml_cpy(ctxL,
                     KQV_merged,
-                    ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
+                    ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx));
         }
 
         // projection
@@ -1425,6 +1443,8 @@ static bool whisper_encode(
                         Vcross),
                     Vcross);
 
+            //struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
+            //struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
             struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
             struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));
 
@@ -1474,7 +1494,8 @@ static bool whisper_decode(
     const int n_layer = hparams.n_text_layer;
 
     const int N = n_tokens;
-    const int M = hparams.n_audio_ctx;
+    //const int M = hparams.n_audio_ctx;
+    const int M = WHISPER_EXPERIMENT_AUDIO_CTX;
 
     struct ggml_init_params params = {
             .mem_size   = wctx.buf_compute.size(),
@@ -2662,7 +2683,7 @@ int whisper_full(
                 //}
 
                 // end of text token
-                if (token.id == whisper_token_eot(ctx)) {
+                if (token.id == whisper_token_eot(ctx) || (i > WHISPER_EXPERIMENT_MAX_TOKENS_PER_SEGMENT)) {
                     if (result_len == 0) {
                         if (seek + seek_delta + 100 >= seek_end) {
                             result_len = i + 1;
@@ -2671,6 +2692,12 @@ int whisper_full(
                             fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__);
                         }
                     }
+
+                    // TODO: TMP TO MAKE STREAM WORK ON RPI4 ===
+                    result_len = i + 1;
+                    seek_delta = 100*WHISPER_CHUNK_SIZE;
+                    // =========================================
+
                     break;
                 }
 
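Note: the temporary block above is a latency workaround for streaming on small devices. As soon as the sampled token is end-of-text, or more than WHISPER_EXPERIMENT_MAX_TOKENS_PER_SEGMENT tokens have been produced, the segment is cut and seek jumps ahead by a full chunk (seek appears to be counted in 10 ms frames, so 100*WHISPER_CHUNK_SIZE = 3000 frames = 30 s). Below is a minimal, self-contained sketch of that cap with a fake sampling loop and an illustrative end-of-text token id; it is not the real whisper_full() loop.

    #include <cstdio>

    int main() {
        const int max_tokens_per_segment = 32;    // WHISPER_EXPERIMENT_MAX_TOKENS_PER_SEGMENT
        const int chunk_size_s           = 30;    // WHISPER_CHUNK_SIZE
        const int token_eot              = 50256; // illustrative end-of-text id, not taken from the vocabulary

        int result_len = 0;
        int seek_delta = 0;

        for (int i = 0; i < 224; ++i) {
            const int token_id = 1000 + i; // stand-in for the sampled token; never hits EOT here

            if (token_id == token_eot || i > max_tokens_per_segment) {
                // force the segment to end and skip a full chunk ahead,
                // mirroring the temporary RPi4 workaround in the diff
                result_len = i + 1;
                seek_delta = 100*chunk_size_s; // 3000 frames of 10 ms = 30 s
                break;
            }
        }

        printf("result_len = %d, seek_delta = %d\n", result_len, seek_delta);
        return 0;
    }
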
@@ -2850,7 +2877,7 @@ int whisper_full_parallel(
 
             // key/value memory for the cross-attention layer
             {
-                const int n_audio_ctx   = hparams.n_audio_ctx;
+                const int n_audio_ctx = hparams.n_audio_ctx;
 
                 const int n_mem      = n_text_layer*n_audio_ctx;
                 const int n_elements = n_text_state*n_mem;
index ea677eafd0c05a6b5a167917eef6c7bad12f270a..769ae643983f8bd8e86aad5021af8c26896206fc 100644 (file)
--- a/whisper.h
+++ b/whisper.h
@@ -24,6 +24,9 @@
 #define WHISPER_HOP_LENGTH  160
 #define WHISPER_CHUNK_SIZE  30
 
+#define WHISPER_EXPERIMENT_AUDIO_CTX 512
+#define WHISPER_EXPERIMENT_MAX_TOKENS_PER_SEGMENT 32
+
 #ifdef __cplusplus
 extern "C" {
 #endif
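
Note: the two experimental constants drive the whisper.cpp changes above. WHISPER_EXPERIMENT_AUDIO_CTX shrinks the encoder's self-attention length from the model's 1500 positions to 512, and WHISPER_EXPERIMENT_MAX_TOKENS_PER_SEGMENT caps each decoded segment at 32 tokens. Below is a quick back-of-the-envelope check of what 512 encoder positions correspond to in seconds, assuming the constants defined in this header and the encoder's stride-2 convolution.

    #include <cstdio>

    int main() {
        const int sample_rate = 16000; // WHISPER_SAMPLE_RATE
        const int hop_length  = 160;   // WHISPER_HOP_LENGTH -> one mel frame per 10 ms
        const int conv_stride = 2;     // the second conv layer halves the time resolution

        const int full_ctx = 1500;     // default n_audio_ctx -> 30 s
        const int exp_ctx  = 512;      // WHISPER_EXPERIMENT_AUDIO_CTX

        const double full_s = (double) full_ctx * conv_stride * hop_length / sample_rate;
        const double exp_s  = (double) exp_ctx  * conv_stride * hop_length / sample_rate;

        printf("full context: %.2f s, experimental context: %.2f s\n", full_s, exp_s);
        return 0;
    }

With these values the experimental context covers about 10.24 s of audio instead of 30 s, which is roughly where the encoder savings on a Raspberry Pi 4 come from.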