]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
whisper : various improvements
authorGeorgi Gerganov <redacted>
Wed, 5 Oct 2022 20:15:10 +0000 (23:15 +0300)
committerGeorgi Gerganov <redacted>
Wed, 5 Oct 2022 20:15:10 +0000 (23:15 +0300)
examples/whisper/main.cpp
examples/whisper/whisper.cpp
include/ggml/ggml.h
src/ggml.c

index 562559a19186632f814df1daadbc76dd323c2860..6d1c55dace01eb6966d959b804e0738df820ea55 100644 (file)
@@ -149,11 +149,11 @@ int main(int argc, char ** argv) {
         // convert to mono, float
         pcmf32.resize(n);
         if (wav.channels == 1) {
-            for (size_t i = 0; i < n; i++) {
+            for (int i = 0; i < n; i++) {
                 pcmf32[i] = float(pcm16[i])/32768.0f;
             }
         } else {
-            for (size_t i = 0; i < n; i++) {
+            for (int i = 0; i < n; i++) {
                 pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
             }
         }
@@ -185,6 +185,9 @@ int main(int argc, char ** argv) {
         wparams.print_progress       = false;
         wparams.print_timestamps     = !params.no_timestamps;
         wparams.print_special_tokens = params.print_special_tokens;
+        wparams.translate            = params.translate;
+        wparams.language             = params.language.c_str();
+        wparams.n_threads            = params.n_threads;
 
         if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
             fprintf(stderr, "%s: failed to process audio\n", argv[0]);
index 4f105eefe40755ca7793dd6d7a88e27d9cf622e8..46a4caa03238acd125e1d5cc51e09f7109bf24d3 100644 (file)
@@ -1031,8 +1031,6 @@ bool whisper_encode(
     const auto & mel_inp = wctx.mel;
     const auto & hparams = model.hparams;
 
-    const int n_vocab = hparams.n_vocab;
-
     const int n_ctx   = hparams.n_audio_ctx;
     const int n_state = hparams.n_audio_state;
     const int n_head  = hparams.n_audio_head;
@@ -1293,7 +1291,8 @@ bool whisper_encode(
         struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);
 
         {
-            struct ggml_cgraph gf = { .n_threads = n_threads };
+            struct ggml_cgraph gf = {};
+            gf.n_threads = n_threads;
 
             ggml_build_forward_expand(&gf, inpO);
             ggml_graph_compute       (ctxL, &gf);
@@ -1329,7 +1328,8 @@ bool whisper_encode(
 
     // run the computation
     {
-        struct ggml_cgraph gf = { .n_threads = n_threads };
+        struct ggml_cgraph gf = {};
+        gf.n_threads = n_threads;
 
         ggml_build_forward_expand(&gf, cur);
         ggml_graph_compute       (ctx0, &gf);
@@ -1353,7 +1353,8 @@ bool whisper_encode(
 
     // pre-compute cross-attention memory
     {
-        struct ggml_cgraph gf = { .n_threads = n_threads };
+        struct ggml_cgraph gf = {};
+        gf.n_threads = n_threads;
 
         // TODO: hack to disconnect the encoded features from the previous graph
         cur->op = GGML_OP_NONE;
@@ -1463,7 +1464,8 @@ bool whisper_decode(
         };
 
         struct ggml_context * ctxL = ggml_init(paramsL);
-        struct ggml_cgraph gf = { .n_threads = n_threads };
+        struct ggml_cgraph gf = {};
+        gf.n_threads = n_threads;
 
         // norm
         {
@@ -1746,7 +1748,8 @@ bool whisper_decode(
 
     // run the computation
     {
-        struct ggml_cgraph gf = { .n_threads = n_threads };
+        struct ggml_cgraph gf = {};
+        gf.n_threads = n_threads;
 
         ggml_build_forward_expand(&gf, cur);
         ggml_graph_compute       (ctx0, &gf);
@@ -2336,7 +2339,7 @@ int whisper_full(
             }
         }
 
-        if (seek >= whisper_n_len(ctx)) {
+        if (seek + 100 >= whisper_n_len(ctx)) {
             break;
         }
 
@@ -2365,7 +2368,6 @@ int whisper_full(
 
         bool done = false;
         int seek_delta = 100*WHISPER_CHUNK_SIZE;
-        whisper_token last_id = 0;
 
         // print the prompt
         //printf("\n\n");
@@ -2395,8 +2397,6 @@ int whisper_full(
             // feel free to experiment!
             //
             {
-                const int n_vocab = whisper_n_vocab(ctx);
-
                 whisper_token id  = 0;
                 whisper_token tid = whisper_token_beg(ctx);
 
@@ -2410,7 +2410,6 @@ int whisper_full(
                     seek_delta = 2*(id - whisper_token_beg(ctx));
                     result_len = i + 1;
                 }
-                last_id = id;
 
                 // add it to the context
                 prompt.push_back(id);
@@ -2444,7 +2443,7 @@ int whisper_full(
 
             std::string text = "";
 
-            for (int i = 0; i < result_cur.size(); i++) {
+            for (int i = 0; i < (int) result_cur.size(); i++) {
                 if (params.print_special_tokens == false && result_cur[i].id >= whisper_token_eot(ctx)) {
                 } else {
                     text += whisper_token_to_str(ctx, result_cur[i].id);
@@ -2464,7 +2463,7 @@ int whisper_full(
                         result_all.push_back({ t0, t1, text });
                     }
                     text = "";
-                    while (result_cur[i].id > whisper_token_beg(ctx) && i < result_cur.size()) {
+                    while (result_cur[i].id > whisper_token_beg(ctx) && i < (int) result_cur.size()) {
                         i++;
                     }
                     i--;
index 465a9b6d165ede8b5a26f12efb06dec448e45815..5b7b2582ef5b1eddda22984f432a697334a8d71e 100644 (file)
@@ -108,7 +108,7 @@ struct ggml_tensor {
     int64_t perf_time_us;
 
     void * data;
-    char pad[8];
+    char padding[8];
 };
 
 // computation graph
index 9b18d819cd35dccbf6f3099f1da62d528d3d93fd..58944893c519f7b376c168287b66aac6644a0ce2 100644 (file)
@@ -1,5 +1,6 @@
 #include "ggml.h"
 
+#include <alloca.h>
 #include <assert.h>
 #include <time.h>
 #include <math.h>
 #include <pthread.h>
 
 #define GGML_DEBUG 0
-#define GGML_MEM_ALIGN 16
+
+#if UINTPTR_MAX == 0xFFFFFFFF
+    #define GGML_MEM_ALIGN 4
+#else
+    #define GGML_MEM_ALIGN 16
+#endif
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -305,6 +311,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 #ifdef __ARM_NEON
     const int n32 = (n & ~31);
 
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
     float16x8_t sum0 = vdupq_n_f16(0);
     float16x8_t sum1 = vdupq_n_f16(0);
     float16x8_t sum2 = vdupq_n_f16(0);
@@ -344,6 +351,61 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 
     float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0f32), vget_high_f32(sum0f32));
     sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1);
+#else
+    float32x4_t sum0 = vdupq_n_f32(0);
+    float32x4_t sum1 = vdupq_n_f32(0);
+    float32x4_t sum2 = vdupq_n_f32(0);
+    float32x4_t sum3 = vdupq_n_f32(0);
+    float32x4_t sum4 = vdupq_n_f32(0);
+    float32x4_t sum5 = vdupq_n_f32(0);
+    float32x4_t sum6 = vdupq_n_f32(0);
+    float32x4_t sum7 = vdupq_n_f32(0);
+
+    float32x4_t x0, x1, x2, x3, x4, x5, x6, x7;
+    float32x4_t y0, y1, y2, y3, y4, y5, y6, y7;
+
+    for (int i = 0; i < n32; i += 32) {
+        x0 = vcvt_f32_f16(vld1_f16(x + i + 0 ));
+        x1 = vcvt_f32_f16(vld1_f16(x + i + 4 ));
+        x2 = vcvt_f32_f16(vld1_f16(x + i + 8 ));
+        x3 = vcvt_f32_f16(vld1_f16(x + i + 12));
+        x4 = vcvt_f32_f16(vld1_f16(x + i + 16));
+        x5 = vcvt_f32_f16(vld1_f16(x + i + 20));
+        x6 = vcvt_f32_f16(vld1_f16(x + i + 24));
+        x7 = vcvt_f32_f16(vld1_f16(x + i + 28));
+
+        y0 = vcvt_f32_f16(vld1_f16(y + i + 0 ));
+        y1 = vcvt_f32_f16(vld1_f16(y + i + 4 ));
+        y2 = vcvt_f32_f16(vld1_f16(y + i + 8 ));
+        y3 = vcvt_f32_f16(vld1_f16(y + i + 12));
+        y4 = vcvt_f32_f16(vld1_f16(y + i + 16));
+        y5 = vcvt_f32_f16(vld1_f16(y + i + 20));
+        y6 = vcvt_f32_f16(vld1_f16(y + i + 24));
+        y7 = vcvt_f32_f16(vld1_f16(y + i + 28));
+
+        sum0 = vfmaq_f32(sum0, x0, y0);
+        sum1 = vfmaq_f32(sum1, x1, y1);
+        sum2 = vfmaq_f32(sum2, x2, y2);
+        sum3 = vfmaq_f32(sum3, x3, y3);
+        sum4 = vfmaq_f32(sum4, x4, y4);
+        sum5 = vfmaq_f32(sum5, x5, y5);
+        sum6 = vfmaq_f32(sum6, x6, y6);
+        sum7 = vfmaq_f32(sum7, x7, y7);
+    }
+
+    // reduce sum0..sum7 to sum0
+    sum0 = vaddq_f32(sum0, sum1);
+    sum2 = vaddq_f32(sum2, sum3);
+    sum4 = vaddq_f32(sum4, sum5);
+    sum6 = vaddq_f32(sum6, sum7);
+    sum0 = vaddq_f32(sum0, sum2);
+    sum4 = vaddq_f32(sum4, sum6);
+    sum0 = vaddq_f32(sum0, sum4);
+
+    // reduce sum0 to sumf
+    float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0), vget_high_f32(sum0));
+    sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1);
+#endif
 
     // leftovers
     for (int i = n32; i < n; ++i) {
@@ -486,6 +548,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
     // NEON 128-bit
     const int n32 = (n & ~31);
 
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
     const float16x8_t v8 = vdupq_n_f16(v);
 
     float16x8_t x0, x1, x2, x3;
@@ -512,6 +575,51 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
         vst1q_f16(y + i + 16, y2);
         vst1q_f16(y + i + 24, y3);
     }
+#else
+    const float32x4_t v40 = vdupq_n_f32(v);
+    const float32x4_t v41 = vdupq_n_f32(v);
+
+    float32x4_t x0, x1, x2, x3, x4, x5, x6, x7;
+    float32x4_t y0, y1, y2, y3, y4, y5, y6, y7;
+
+    for (int i = 0; i < n32; i += 32) {
+        y0 = vcvt_f32_f16(vld1_f16(y + i + 0 ));
+        y1 = vcvt_f32_f16(vld1_f16(y + i + 4 ));
+        y2 = vcvt_f32_f16(vld1_f16(y + i + 8 ));
+        y3 = vcvt_f32_f16(vld1_f16(y + i + 12));
+        y4 = vcvt_f32_f16(vld1_f16(y + i + 16));
+        y5 = vcvt_f32_f16(vld1_f16(y + i + 20));
+        y6 = vcvt_f32_f16(vld1_f16(y + i + 24));
+        y7 = vcvt_f32_f16(vld1_f16(y + i + 28));
+
+        x0 = vcvt_f32_f16(vld1_f16(x + i + 0 ));
+        x1 = vcvt_f32_f16(vld1_f16(x + i + 4 ));
+        x2 = vcvt_f32_f16(vld1_f16(x + i + 8 ));
+        x3 = vcvt_f32_f16(vld1_f16(x + i + 12));
+        x4 = vcvt_f32_f16(vld1_f16(x + i + 16));
+        x5 = vcvt_f32_f16(vld1_f16(x + i + 20));
+        x6 = vcvt_f32_f16(vld1_f16(x + i + 24));
+        x7 = vcvt_f32_f16(vld1_f16(x + i + 28));
+
+        y0 = vfmaq_f32(y0, x0, v40);
+        y1 = vfmaq_f32(y1, x1, v40);
+        y2 = vfmaq_f32(y2, x2, v40);
+        y3 = vfmaq_f32(y3, x3, v40);
+        y4 = vfmaq_f32(y4, x4, v41);
+        y5 = vfmaq_f32(y5, x5, v41);
+        y6 = vfmaq_f32(y6, x6, v41);
+        y7 = vfmaq_f32(y7, x7, v41);
+
+        vst1_f16(y + i + 0 , vcvt_f16_f32(y0));
+        vst1_f16(y + i + 4 , vcvt_f16_f32(y1));
+        vst1_f16(y + i + 8 , vcvt_f16_f32(y2));
+        vst1_f16(y + i + 12, vcvt_f16_f32(y3));
+        vst1_f16(y + i + 16, vcvt_f16_f32(y4));
+        vst1_f16(y + i + 20, vcvt_f16_f32(y5));
+        vst1_f16(y + i + 24, vcvt_f16_f32(y6));
+        vst1_f16(y + i + 28, vcvt_f16_f32(y7));
+    }
+#endif
 
     // leftovers
     for (int i = n32; i < n; ++i) {
@@ -911,16 +1019,18 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     if (is_first_call) {
         const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
 
+        ggml_fp16_t ii;
         for (int i = 0; i < (1 << 16); ++i) {
-            uint16_t ii = (uint16_t) i;
-            const float f = ggml_fp16_to_fp32(*(ggml_fp16_t *)(&ii));
+            uint16_t ui = i;
+            memcpy(&ii, &ui, sizeof(ii));
+            const float f = ggml_fp16_to_fp32(ii);
             table_gelu_f16[i] = ggml_fp32_to_fp16(ggml_gelu_f32(f));
             table_exp_f16[i] = ggml_fp32_to_fp16(exp(f));
         }
 
         const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
-        GGML_PRINT_DEBUG("%s: GELU table initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+        GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
 
         is_first_call = false;
     }
@@ -4427,13 +4537,15 @@ void ggml_compute_forward_soft_max_f32(
 
         ggml_float sum = 0.0;
 
+        uint16_t ss;
         for (int i = 0; i < nc; i++) {
             if (p[i] == -INFINITY) {
                 p[i] = 0.0;
             } else {
                 //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
                 ggml_fp16_t s = ggml_fp32_to_fp16(p[i] - max);
-                const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]);
+                memcpy(&ss, &s, sizeof(ss));
+                const float val = ggml_fp16_to_fp32(table_exp_f16[ss]);
                 sum += val;
                 p[i] = val;
             }
@@ -5234,13 +5346,15 @@ void ggml_compute_forward_flash_attn_f32(
 
             ggml_float sum = 0.0;
 
+            uint16_t ss;
             for (int i = 0; i < M; i++) {
                 if (S[i] == -INFINITY) {
                     S[i] = 0.0;
                 } else {
                     //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
                     ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max);
-                    const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]);
+                    memcpy(&ss, &s, sizeof(ss));
+                    const float val = ggml_fp16_to_fp32(table_exp_f16[ss]);
                     sum += val;
                     S[i] = val;
                 }
@@ -5413,13 +5527,15 @@ void ggml_compute_forward_flash_attn_f16(
 
             ggml_float sum = 0.0;
 
+            uint16_t ss;
             for (int i = 0; i < M; i++) {
                 if (S[i] == -INFINITY) {
                     S[i] = 0.0;
                 } else {
                     //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
                     ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max);
-                    const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]);
+                    memcpy(&ss, &s, sizeof(ss));
+                    const float val = ggml_fp16_to_fp32(table_exp_f16[ss]);
                     sum += val;
                     S[i] = val;
                 }