]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Flash + language support (ref #2)
authorGeorgi Gerganov <redacted>
Wed, 28 Sep 2022 17:46:05 +0000 (20:46 +0300)
committerGeorgi Gerganov <redacted>
Wed, 28 Sep 2022 18:07:32 +0000 (21:07 +0300)
- Achieved big performance improvement + memory usage reduction
- Can now translate / transcribe different languages

Makefile
README.md
download-ggml-model.sh
ggml.c
ggml.h
main.cpp

index 773bde0e93a4f3e733d99a651ff7ac0f79293cac..1aed7bf5cac0d3c78b4ebb1fa244f6ac38d249ee 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -30,11 +30,16 @@ samples:
 # runs it on all samples in the folder "./samples":
 
 .PHONY: tiny.en
+.PHONY: tiny
 .PHONY: base.en
-.PHONY: medium.en
+.PHONY: base
 .PHONY: small.en
+.PHONY: small
+.PHONY: medium.en
+.PHONY: medium
+.PHONY: large
 
-tiny.en base.en medium.en small.en: main
+tiny.en tiny base.en base small.en small medium.en medium large: main
        bash ./download-ggml-model.sh $@
        @echo ""
        @echo "==============================================="
index 891a94a1dc8ae83598529ff6873d1e31cc467adf..f4877cf217a649928a51cd20133ccf303d651549 100644 (file)
--- a/README.md
+++ b/README.md
@@ -4,7 +4,8 @@ C/C++ port of [OpenAI's Whisper](https://github.com/openai/whisper) speech-to-te
 
 - Plain C/C++ implementation without dependencies
 - ARM_NEON and AVX intrinsics support
-- F16 support
+- Mixed F16 / F32 support
+- Low memory usage (Flash Attention + Flash Forward)
 
 ## Usage
 
@@ -27,9 +28,33 @@ For a quick demo, simply run `make base.en`:
 ```bash
 $ make base.en
 
-Downloading base.en (142 MB just once)
-mkdir -p models
-models/ggml-base.en.bin      100%[=================================>] 141.11M  7.50MB/s    in 19s
+gcc -pthread -O3 -mavx -mavx2 -mfma -mf16c -c ggml.c
+g++ -pthread -O3 -std=c++11 -c main.cpp
+g++ -o main ggml.o main.o
+./main -h
+
+usage: ./main [options]
+
+options:
+  -h,       --help           show this help message and exit
+  -s SEED,  --seed SEED      RNG seed (default: -1)
+  -t N,     --threads N      number of threads to use during computation (default: 4)
+  -T N,     --tokens N       maximum number of tokens to generate per iteration (default: 64)
+  -v,       --verbose        verbose output
+            --translate      translate from source language to english
+  -ps,      --print_special  print special tokens
+  -l LANG,  --language LANG  spoken language (default: en)
+  -m FNAME, --model FNAME    model path (default: models/ggml-base.en.bin)
+  -f FNAME, --file FNAME     input WAV file path (default: samples/jfk.wav)
+
+bash ./download-ggml-model.sh base.en
+Downloading ggml model base.en ...
+models/ggml-base.en.bin         100%[=====================================>] 141.11M  7.84MB/s    in 18s
+Done! Model 'base.en' saved in 'models/ggml-base.en.bin'
+You can now use it like this:
+
+  $ ./main -m models/ggml-base.en.bin -f samples/jfk.wav
+
 
 ===============================================
 Running base.en on all samples in ./samples ...
@@ -52,23 +77,24 @@ whisper_model_load: n_text_layer  = 6
 whisper_model_load: n_mels        = 80
 whisper_model_load: f16           = 1
 whisper_model_load: type          = 2
-whisper_model_load: mem_required  = 782.00 MB
+whisper_model_load: mem_required  = 611.00 MB
 whisper_model_load: adding 1607 extra tokens
-whisper_model_load: ggml ctx size = 186.26 MB
-whisper_model_load: memory size =    45.66 MB
+whisper_model_load: ggml ctx size = 163.43 MB
+whisper_model_load: memory size =    22.83 MB
 whisper_model_load: model size  =   140.54 MB
 log_mel_spectrogram: n_sample = 176000, n_len = 1100
 log_mel_spectrogram: recording length: 11.000000 s
 
- And so my fellow Americans ask not what your country can do for you. Ask what you can do for your country.
+main: processing 176000 samples (11.0 sec), 4 threads, lang = english, task = transcribe ...
 
-main:     load time =    60.62 ms
-main:      mel time =    38.69 ms
-main:   sample time =     2.36 ms
-main:   encode time =   875.63 ms / 145.94 ms per layer
-main:   decode time =   103.17 ms
-main:    total time =  1081.13 ms
+ And so my fellow Americans ask not what your country can do for you. Ask what you can do for your country.
 
+main:     load time =    71.89 ms
+main:      mel time =    36.95 ms
+main:   sample time =     2.10 ms
+main:   encode time =   700.94 ms / 116.82 ms per layer
+main:   decode time =    86.14 ms
+main:    total time =   898.72 ms
 ```
 
 The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
@@ -81,13 +107,18 @@ make samples
 
 This will download a few more audio files from Wikipedia and convert them to 16-bit WAV format via `ffmpeg`.
 
-You can download and run the other `.en` models as follows:
+You can download and run the other models as follows:
 
 ```
 make tiny.en
+make tiny
 make base.en
+make base
 make small.en
+make small
 make medium.en
+make medium
+make large
 ```
 
 For detailed usage instructions, run: `./main -h`
@@ -101,10 +132,8 @@ ffmpeg -i input.mp3 -ar 16000 -ac 1 -c:a pcm_s16le output.wav
 
 ## Limitations
 
-- Only `.en` models are supported
 - Very basic greedy sampling scheme - always pick up the top token
 - No timestamps
-- English only
 - Inference only
 - Runs on the CPU
 - Only mono-channel 16-bit WAV is supported
@@ -113,10 +142,11 @@ ffmpeg -i input.mp3 -ar 16000 -ac 1 -c:a pcm_s16le output.wav
 
 | Model | Disk | Mem |
 | ---   | --- | --- |
-| tiny.en | 75 MB | ~600 MB |
-| base.en | 142 MB | ~800 MB |
-| small.en | 466 MB | ~1.6 GB |
-| medium.en | 1.5 GB | ~3.5 GB |
+| tiny | 75 MB | ~460 MB |
+| base | 142 MB | ~620 MB |
+| small | 466 MB | ~1.3 GB |
+| medium | 1.5 GB | ~2.8 GB |
+| large | 2.9 GB | ~4.9 GB |
 
 ## ggml format
 
index 3d5fa50b82d84d119e91e056f97c0fb3f0014705..d3009d27c99044ec3dd692785e03fb374b39d4d1 100755 (executable)
@@ -6,7 +6,7 @@
 ggml_path=$(dirname $(realpath $0))
 
 # Whisper models
-models=( "tiny.en" "base.en" "small.en" "medium.en" )
+models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large" )
 
 # list available models
 function list_models {
diff --git a/ggml.c b/ggml.c
index c29422cef8ea3112bf5d35030653bc7539bd74b3..9b18d819cd35dccbf6f3099f1da62d528d3d93fd 100644 (file)
--- a/ggml.c
+++ b/ggml.c
 #define UNUSED(x) (void)(x)
 #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
 
-#define GGML_ASSERT(x) assert(x)
+#define GGML_ASSERT(x) \
+    do { \
+        if (!(x)) { \
+            fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            abort(); \
+        } \
+    } while (0)
 
 #ifdef GGML_USE_ACCELERATE
 #include <Accelerate/Accelerate.h>
@@ -118,6 +124,16 @@ ggml_fp16_t ggml_fp32_to_fp16(float f) {
 }
 #endif
 
+//
+// global data
+//
+
+// precomputed gelu table for f16 (128 KB)
+static ggml_fp16_t table_gelu_f16[1 << 16];
+
+// precomputed exp table for f16 (128 KB)
+static ggml_fp16_t table_exp_f16[1 << 16];
+
 //
 // timing
 //
@@ -331,7 +347,6 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 
     // leftovers
     for (int i = n32; i < n; ++i) {
-        GGML_ASSERT(false); // should not end up here
         sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
     }
 #elif defined(__AVX2__)
@@ -375,7 +390,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 
     // leftovers
     for (int i = n32; i < n; ++i) {
-        GGML_ASSERT(false);
+        //GGML_ASSERT(false);
         sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
     }
 #else
@@ -558,12 +573,20 @@ inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) {
 const ggml_float GELU_COEF_A    = 0.044715;
 const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
 
-inline static void ggml_vec_gelu_f32 (const int n, float * y, const float * x) {
+inline static float ggml_gelu_f32(float x) {
+    return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
+}
+
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
     for (int i = 0; i < n; ++i) {
-        //y[i] = 0.5f*x[i]*(1.f + tanhf(SQRT_2_OVER_PI*(x[i] + 0.044715f*x[i]*x[i]*x[i])));
-        //0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3))))
-        const ggml_float xx = x[i];
-        y[i] = 0.5*xx*(1.0 + tanh(SQRT_2_OVER_PI*xx*(1.0 + GELU_COEF_A*xx*xx)));
+        y[i] = ggml_gelu_f32(x[i]);
+    }
+}
+
+inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        y[i] = table_gelu_f16[i16[i]];
     }
 }
 
@@ -641,6 +664,9 @@ const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "ROPE",
     "CONV_1D_1S",
     "CONV_1D_2S",
+
+    "FLASH_ATTN",
+    "FLASH_FF",
 };
 
 const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
@@ -678,6 +704,9 @@ const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rope(x)",
     "conv_1d_1s(x)",
     "conv_1d_2s(x)",
+
+    "flash_attn(x)",
+    "flash_ff(x)",
 };
 
 //
@@ -878,6 +907,24 @@ int ggml_up64(int n) {
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_context * ggml_init(struct ggml_init_params params) {
+    static bool is_first_call = true;
+    if (is_first_call) {
+        const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
+
+        for (int i = 0; i < (1 << 16); ++i) {
+            uint16_t ii = (uint16_t) i;
+            const float f = ggml_fp16_to_fp32(*(ggml_fp16_t *)(&ii));
+            table_gelu_f16[i] = ggml_fp32_to_fp16(ggml_gelu_f32(f));
+            table_exp_f16[i] = ggml_fp32_to_fp16(exp(f));
+        }
+
+        const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
+
+        GGML_PRINT_DEBUG("%s: GELU table initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+
+        is_first_call = false;
+    }
+
     // find non-used context in g_state
     struct ggml_context * ctx = NULL;
 
@@ -900,7 +947,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     }
 
     if (ctx == NULL) {
-        GGML_PRINT_DEBUG("%s\n", "ggml_init: no unused context found");
+        GGML_PRINT_DEBUG("%s: no unused context found\n", __func__);
         return NULL;
     }
 
@@ -923,8 +970,8 @@ void ggml_free(struct ggml_context * ctx) {
         if (&g_state.contexts[i].context == ctx) {
             g_state.contexts[i].used = false;
 
-            GGML_PRINT_DEBUG("ggml_free: context %d with %d objects has been freed. memory used = %zu\n",
-                    i, ctx->n_objects, ctx->objects_end->offset + ctx->objects_end->size);
+            GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
+                    __func__, i, ctx->n_objects, ctx->objects_end->offset + ctx->objects_end->size);
 
             if (ctx->mem_buffer_owned) {
                 free(ctx->mem_buffer);
@@ -1010,6 +1057,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.grad         =*/ NULL,
         /*.src0         =*/ NULL,
         /*.src1         =*/ NULL,
+        /*.opt          =*/ { NULL },
         /*.n_tasks      =*/ 0,
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
@@ -1079,6 +1127,14 @@ struct ggml_tensor * ggml_new_tensor_4d(
     return ggml_new_tensor(ctx, type, 4, ne);
 }
 
+struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+
+    ggml_set_i32(result, value);
+
+    return result;
+}
+
 struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
 
@@ -1096,6 +1152,58 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
     return tensor;
 }
 
+struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
+    const int n     = ggml_nrows(tensor);
+    const int nc    = tensor->ne[0];
+    const size_t n1 = tensor->nb[1];
+
+    char * const data = tensor->data;
+
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                assert(tensor->nb[0] == sizeof(int8_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I16:
+            {
+                assert(tensor->nb[0] == sizeof(int16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I32:
+            {
+                assert(tensor->nb[0] == sizeof(int32_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_F16:
+            {
+                assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_F32:
+            {
+                assert(tensor->nb[0] == sizeof(float));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+
+    return tensor;
+}
+
 struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
     const int n     = ggml_nrows(tensor);
     const int nc    = tensor->ne[0];
@@ -1148,40 +1256,109 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
     return tensor;
 }
 
+int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+                return ((int8_t *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_I16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+                return ((int16_t *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_I32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+                return ((int32_t *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+                return ggml_fp16_to_fp32(((ggml_fp16_t *)(tensor->data))[i]);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(float));
+                return ((float *)(tensor->data))[i];
+            } break;
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+    return 0.0f;
+}
+
+void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+                ((int8_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+                ((int16_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+                ((int32_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+                ((ggml_fp16_t *)(tensor->data))[i] = ggml_fp32_to_fp16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(float));
+                ((float *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
-                assert(tensor->nb[0] == sizeof(int8_t));
+                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
                 return ((int8_t *)(tensor->data))[i];
             } break;
         case GGML_TYPE_I16:
             {
-                assert(tensor->nb[0] == sizeof(int16_t));
+                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
                 return ((int16_t *)(tensor->data))[i];
             } break;
         case GGML_TYPE_I32:
             {
-                assert(tensor->nb[0] == sizeof(int32_t));
+                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
                 return ((int32_t *)(tensor->data))[i];
             } break;
         case GGML_TYPE_F16:
             {
-                assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
                 return ggml_fp16_to_fp32(((ggml_fp16_t *)(tensor->data))[i]);
             } break;
         case GGML_TYPE_F32:
             {
-                assert(tensor->nb[0] == sizeof(float));
+                GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 return ((float *)(tensor->data))[i];
             } break;
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 
-    assert(false);
     return 0.0f;
 }
 
@@ -1189,32 +1366,32 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
-                assert(tensor->nb[0] == sizeof(int8_t));
+                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
                 ((int8_t *)(tensor->data))[i] = value;
             } break;
         case GGML_TYPE_I16:
             {
-                assert(tensor->nb[0] == sizeof(int16_t));
+                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
                 ((int16_t *)(tensor->data))[i] = value;
             } break;
         case GGML_TYPE_I32:
             {
-                assert(tensor->nb[0] == sizeof(int32_t));
+                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
                 ((int32_t *)(tensor->data))[i] = value;
             } break;
         case GGML_TYPE_F16:
             {
-                assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
                 ((ggml_fp16_t *)(tensor->data))[i] = ggml_fp32_to_fp16(value);
             } break;
         case GGML_TYPE_F32:
             {
-                assert(tensor->nb[0] == sizeof(float));
+                GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 ((float *)(tensor->data))[i] = value;
             } break;
         case GGML_TYPE_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     }
 }
@@ -2308,6 +2485,70 @@ struct ggml_tensor * ggml_conv_1d_2s(
     return result;
 }
 
+// ggml_flash_attn
+
+struct ggml_tensor * ggml_flash_attn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        bool                  masked) {
+    assert(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+
+    bool is_node = false;
+
+    if (q->grad || k->grad || v->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne);
+
+    result->op   = GGML_OP_FLASH_ATTN;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = q;
+    result->src1 = k;
+    result->opt[0] = v;
+    result->opt[1] = ggml_new_i32(ctx, masked ? 1 : 0);
+
+    return result;
+}
+
+// ggml_flash_ff
+
+struct ggml_tensor * ggml_flash_ff(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b0,
+        struct ggml_tensor  * b1,
+        struct ggml_tensor  * c0,
+        struct ggml_tensor  * c1) {
+    assert(ggml_can_mul_mat(b0, a));
+    // TODO: more checks
+
+    bool is_node = false;
+
+    if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne);
+
+    result->op   = GGML_OP_FLASH_FF;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b0;
+    result->opt[0] = b1;
+    result->opt[1] = c0;
+    result->opt[2] = c1;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_set_param(
@@ -2415,7 +2656,7 @@ void ggml_compute_forward_dup_f32(
             GGML_ASSERT(false); // TODO: implement
         }
     } else {
-        printf("%s: this is not optimal - fix me\n", __func__);
+        //printf("%s: this is not optimal - fix me\n", __func__);
 
         if (dst->type == GGML_TYPE_F32) {
             int id = 0;
@@ -4185,10 +4426,17 @@ void ggml_compute_forward_soft_max_f32(
         }
 
         ggml_float sum = 0.0;
+
         for (int i = 0; i < nc; i++) {
-            const ggml_float v = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
-            sum += v;
-            p[i] = v;
+            if (p[i] == -INFINITY) {
+                p[i] = 0.0;
+            } else {
+                //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
+                ggml_fp16_t s = ggml_fp32_to_fp16(p[i] - max);
+                const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]);
+                sum += val;
+                p[i] = val;
+            }
         }
 
         assert(sum > 0.0f);
@@ -4362,7 +4610,6 @@ void ggml_compute_forward_conv_1d_1s_f16_f32(
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
-    // WHISPER
     if (params->type == GGML_TASK_INIT) {
         // TODO: fix this memset (wsize is overestimated)
         memset(params->wdata, 0, params->wsize);
@@ -4483,7 +4730,6 @@ void ggml_compute_forward_conv_1d_1s_f32(
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
 
-    // WHISPER
     if (params->type == GGML_TASK_INIT) {
         // TODO: fix this memset (wsize is overestimated)
         memset(params->wdata, 0, params->wsize);
@@ -4630,7 +4876,6 @@ void ggml_compute_forward_conv_1d_2s_f16_f32(
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
-    // WHISPER
     if (params->type == GGML_TASK_INIT) {
         // TODO: fix this memset (wsize is overestimated)
         memset(params->wdata, 0, params->wsize);
@@ -4751,7 +4996,6 @@ void ggml_compute_forward_conv_1d_2s_f32(
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
 
-    // WHISPER
     if (params->type == GGML_TASK_INIT) {
         // TODO: fix this memset (wsize is overestimated)
         memset(params->wdata, 0, params->wsize);
@@ -4841,6 +5085,607 @@ void ggml_compute_forward_conv_1d_2s(
     }
 }
 
+// ggml_compute_forward_flash_attn
+
+void ggml_compute_forward_flash_attn_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const bool masked,
+             struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int neq0 = q->ne[0];
+    const int neq1 = q->ne[1];
+    const int neq2 = q->ne[2];
+    const int neq3 = q->ne[3];
+
+    const int nek0 = k->ne[0];
+    const int nek1 = k->ne[1];
+    //const int nek2 = k->ne[2];
+    //const int nek3 = k->ne[3];
+
+    //const int nev0 = v->ne[0];
+    const int nev1 = v->ne[1];
+    //const int nev2 = v->ne[2];
+    //const int nev3 = v->ne[3];
+
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    //const int ne2  = dst->ne[2];
+    //const int ne3  = dst->ne[3];
+
+    const int nbk0 = k->nb[0];
+    const int nbk1 = k->nb[1];
+    const int nbk2 = k->nb[2];
+    const int nbk3 = k->nb[3];
+
+    const int nbq0 = q->nb[0];
+    const int nbq1 = q->nb[1];
+    const int nbq2 = q->nb[2];
+    const int nbq3 = q->nb[3];
+
+    const int nbv0 = v->nb[0];
+    const int nbv1 = v->nb[1];
+    const int nbv2 = v->nb[2];
+    const int nbv3 = v->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int D = neq0;
+    const int N = neq1;
+    const int P = nek1 - N;
+    const int M = P + N;
+
+    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(P >= 0);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(float));
+    GGML_ASSERT(nbv0 == sizeof(float));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev1 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nek1 == N + P);
+    GGML_ASSERT(nev1 == D);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (params->type == GGML_TASK_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float scale = 1.0/sqrt((double) D);
+
+    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        float * S = (float *) params->wdata + ith*(M + CACHE_LINE_SIZE_F32);
+
+        for (int ic = 0; ic < nek1; ++ic) {
+            // k indices
+            const int ik3 = iq3;
+            const int ik2 = iq2;
+            const int ik1 = ic;
+
+            // S indices
+            const int i1 = ik1;
+
+            ggml_vec_dot_f32(neq0,
+                    S + i1,
+                    (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                    (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+        }
+
+        // scale
+        ggml_vec_scale_f32(nek1, S, scale);
+
+        if (masked) {
+            for (int i = P; i < M; i++) {
+                if (i > P + iq1) {
+                    S[i] = -INFINITY;
+                }
+            }
+        }
+
+        // softmax
+        {
+            float max = -INFINITY;
+            for (int i = 0; i < M; i++) {
+                max = MAX(max, S[i]);
+            }
+
+            ggml_float sum = 0.0;
+
+            for (int i = 0; i < M; i++) {
+                if (S[i] == -INFINITY) {
+                    S[i] = 0.0;
+                } else {
+                    //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
+                    ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max);
+                    const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]);
+                    sum += val;
+                    S[i] = val;
+                }
+            }
+
+            assert(sum > 0.0f);
+
+            sum = 1.0/sum;
+            ggml_vec_scale_f32(M, S, sum);
+        }
+
+        for (int ic = 0; ic < nev1; ++ic) {
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2;
+            const int i3 = iq3;
+
+            ggml_vec_dot_f32(nek1,
+                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
+                    (float *) ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                    S);
+        }
+    }
+}
+
+void ggml_compute_forward_flash_attn_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const bool masked,
+             struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int neq0 = q->ne[0];
+    const int neq1 = q->ne[1];
+    const int neq2 = q->ne[2];
+    const int neq3 = q->ne[3];
+
+    const int nek0 = k->ne[0];
+    const int nek1 = k->ne[1];
+    //const int nek2 = k->ne[2];
+    //const int nek3 = k->ne[3];
+
+    //const int nev0 = v->ne[0];
+    const int nev1 = v->ne[1];
+    //const int nev2 = v->ne[2];
+    //const int nev3 = v->ne[3];
+
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    //const int ne2  = dst->ne[2];
+    //const int ne3  = dst->ne[3];
+
+    const int nbk0 = k->nb[0];
+    const int nbk1 = k->nb[1];
+    const int nbk2 = k->nb[2];
+    const int nbk3 = k->nb[3];
+
+    const int nbq0 = q->nb[0];
+    const int nbq1 = q->nb[1];
+    const int nbq2 = q->nb[2];
+    const int nbq3 = q->nb[3];
+
+    const int nbv0 = v->nb[0];
+    const int nbv1 = v->nb[1];
+    const int nbv2 = v->nb[2];
+    const int nbv3 = v->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int D = neq0;
+    const int N = neq1;
+    const int P = nek1 - N;
+    const int M = P + N;
+
+    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(P >= 0);
+
+    GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev1 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nek1 == N + P);
+    GGML_ASSERT(nev1 == D);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (params->type == GGML_TASK_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float scale = 1.0/sqrt((double) D);
+
+    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
+
+        for (int ic = 0; ic < nek1; ++ic) {
+            // k indices
+            const int ik3 = iq3;
+            const int ik2 = iq2;
+            const int ik1 = ic;
+
+            // S indices
+            const int i1 = ik1;
+
+            ggml_vec_dot_f16(neq0,
+                    S + i1,
+                    (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                    (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+        }
+
+        // scale
+        ggml_vec_scale_f32(nek1, S, scale);
+
+        if (masked) {
+            for (int i = P; i < M; i++) {
+                if (i > P + iq1) {
+                    S[i] = -INFINITY;
+                }
+            }
+        }
+
+        // softmax
+        {
+            float max = -INFINITY;
+            for (int i = 0; i < M; i++) {
+                max = MAX(max, S[i]);
+            }
+
+            ggml_float sum = 0.0;
+
+            for (int i = 0; i < M; i++) {
+                if (S[i] == -INFINITY) {
+                    S[i] = 0.0;
+                } else {
+                    //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max);
+                    ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max);
+                    const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]);
+                    sum += val;
+                    S[i] = val;
+                }
+            }
+
+            assert(sum > 0.0f);
+
+            sum = 1.0/sum;
+            ggml_vec_scale_f32(M, S, sum);
+        }
+
+        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
+
+        for (int i = 0; i < M; i++) {
+            S16[i] = ggml_fp32_to_fp16(S[i]);
+        }
+
+        for (int ic = 0; ic < nev1; ++ic) {
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2;
+            const int i3 = iq3;
+
+            ggml_vec_dot_f16(nek1,
+                    (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1  + i2*nb2  + i3*nb3)),
+                    (ggml_fp16_t *) ((char *) v->data   + (         ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                    S16);
+        }
+    }
+}
+
+void ggml_compute_forward_flash_attn(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    switch (q->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_flash_attn_f16(params, q, k, v, masked, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_flash_ff
+
+void ggml_compute_forward_flash_ff_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,  // F16
+        const struct ggml_tensor * b0, // F16 fc_w
+        const struct ggml_tensor * b1, // F32 fc_b
+        const struct ggml_tensor * c0, // F16 proj_w
+        const struct ggml_tensor * c1, // F32 proj_b
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int nea0 = a->ne[0];
+    const int nea1 = a->ne[1];
+    const int nea2 = a->ne[2];
+    const int nea3 = a->ne[3];
+
+    const int neb00 = b0->ne[0];
+    const int neb01 = b0->ne[1];
+    //const int neb02 = b0->ne[2];
+    //const int neb03 = b0->ne[3];
+
+    const int neb10 = b1->ne[0];
+    const int neb11 = b1->ne[1];
+    //const int neb12 = b1->ne[2];
+    //const int neb13 = b1->ne[3];
+
+    const int nec00 = c0->ne[0];
+    const int nec01 = c0->ne[1];
+    //const int nec02 = c0->ne[2];
+    //const int nec03 = c0->ne[3];
+
+    const int nec10 = c1->ne[0];
+    const int nec11 = c1->ne[1];
+    //const int nec12 = c1->ne[2];
+    //const int nec13 = c1->ne[3];
+
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    //const int ne3 = dst->ne[3];
+
+    const int nba0 = a->nb[0];
+    const int nba1 = a->nb[1];
+    const int nba2 = a->nb[2];
+    const int nba3 = a->nb[3];
+
+    const int nbb00 = b0->nb[0];
+    const int nbb01 = b0->nb[1];
+    const int nbb02 = b0->nb[2];
+    const int nbb03 = b0->nb[3];
+
+    const int nbb10 = b1->nb[0];
+    //const int nbb11 = b1->nb[1];
+    //const int nbb12 = b1->nb[2];
+    //const int nbb13 = b1->nb[3];
+
+    const int nbc00 = c0->nb[0];
+    const int nbc01 = c0->nb[1];
+    const int nbc02 = c0->nb[2];
+    const int nbc03 = c0->nb[3];
+
+    const int nbc10 = c1->nb[0];
+    //const int nbc11 = c1->nb[1];
+    //const int nbc12 = c1->nb[2];
+    //const int nbc13 = c1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int D = nea0;
+    //const int N = nea1;
+    const int M = neb01;
+
+    GGML_ASSERT(ne0 == nea0);
+    GGML_ASSERT(ne1 == nea1);
+    GGML_ASSERT(ne2 == nea2);
+
+    GGML_ASSERT(nba0  == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbb10 == sizeof(float));
+    GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbc10 == sizeof(float));
+
+    GGML_ASSERT(neb00 == D);
+    GGML_ASSERT(neb01 == M);
+    GGML_ASSERT(neb10 == M);
+    GGML_ASSERT(neb11 == 1);
+
+    GGML_ASSERT(nec00 == M);
+    GGML_ASSERT(nec01 == D);
+    GGML_ASSERT(nec10 == D);
+    GGML_ASSERT(nec11 == 1);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (params->type == GGML_TASK_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by a rows using ggml_vec_dot_f32
+
+    // total rows in a
+    const int nr = nea1*nea2*nea3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // a indices
+        const int ia3 = ir/(nea2*nea1);
+        const int ia2 = (ir - ia3*nea2*nea1)/nea1;
+        const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
+
+        float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
+
+        for (int ic = 0; ic < neb01; ++ic) {
+            // b0 indices
+            const int ib03 = ia3;
+            const int ib02 = ia2;
+            const int ib01 = ic;
+
+            // S indices
+            const int i1 = ib01;
+
+            ggml_vec_dot_f16(nea0,
+                    S + i1,
+                    (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)),
+                    (ggml_fp16_t *) ((char *)  a->data + ( ia1*nba1  +  ia2*nba2  +  ia3*nba3)));
+        }
+
+        ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
+        //ggml_vec_gelu_f32(neb01, S, S);
+
+        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
+
+        for (int i = 0; i < M; i++) {
+            S16[i] = ggml_fp32_to_fp16(S[i]);
+        }
+
+        ggml_vec_gelu_f16(neb01, S16, S16);
+
+        {
+            // dst indices
+            const int i1 = ia1;
+            const int i2 = ia2;
+            const int i3 = ia3;
+
+            for (int ic = 0; ic < nec01; ++ic) {
+
+                ggml_vec_dot_f16(neb01,
+                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1   + i2*nb2   + i3*nb3)),
+                        (ggml_fp16_t *) ((char *) c0->data  + (         ic*nbc01 + i2*nbc02 + i3*nbc03)),
+                        S16);
+            }
+
+            ggml_vec_add_f32(nec01,
+                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
+                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
+                    (float *) c1->data);
+        }
+    }
+}
+
+void ggml_compute_forward_flash_ff(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,
+        const struct ggml_tensor * b0,
+        const struct ggml_tensor * b1,
+        const struct ggml_tensor * c0,
+        const struct ggml_tensor * c1,
+        struct ggml_tensor * dst) {
+    switch (b0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_flash_ff_f16(params, a, b0, b1, c0, c1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                GGML_ASSERT(false); // TODO
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                assert(false);
+            } break;
+    }
+}
+
 /////////////////////////////////
 
 void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -4967,13 +5812,24 @@ void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tenso
             {
                 ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_FLASH_ATTN:
+            {
+                int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+                GGML_ASSERT(t == 0 || t == 1);
+                bool masked = t != 0;
+                ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor);
+            } break;
+        case GGML_OP_FLASH_FF:
+            {
+                ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
+            } break;
         case GGML_OP_NONE:
             {
                 // nop
             } break;
         case GGML_OP_COUNT:
             {
-                assert(false);
+                GGML_ASSERT(false);
             } break;
     };
 }
@@ -5205,6 +6061,14 @@ void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tenso
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_FLASH_ATTN:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
+        case GGML_OP_FLASH_FF:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -5246,6 +6110,12 @@ void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node)
         ggml_visit_parents(cgraph, node->src1);
     }
 
+    for (int i = 0; i < GGML_MAX_OPT; ++i) {
+        if (node->opt[i]) {
+            ggml_visit_parents(cgraph, node->opt[i]);
+        }
+    }
+
     if (node->op == GGML_OP_NONE && node->grad == NULL) {
         // reached a leaf node, not part of the gradient graph (e.g. a constant)
         assert(cgraph->n_leafs < GGML_MAX_NODES);
@@ -5591,7 +6461,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 case GGML_OP_CONV_1D_1S:
                 case GGML_OP_CONV_1D_2S:
                     {
-                        // WHISPER
                         node->n_tasks = n_threads;
 
                         GGML_ASSERT(node->src0->ne[3] == 1);
@@ -5617,6 +6486,42 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                             GGML_ASSERT(false);
                         }
 
+                        work_size = MAX(work_size, cur);
+                    } break;
+                case GGML_OP_FLASH_ATTN:
+                    {
+                        node->n_tasks = n_threads;
+
+                        size_t cur = 0;
+
+                        if (node->src1->type == GGML_TYPE_F32) {
+                            cur  = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
+                        }
+
+                        if (node->src1->type == GGML_TYPE_F16) {
+                            cur  = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
+                        }
+
+                        work_size = MAX(work_size, cur);
+                    } break;
+                case GGML_OP_FLASH_FF:
+                    {
+                        node->n_tasks = n_threads;
+
+                        size_t cur = 0;
+
+                        if (node->src1->type == GGML_TYPE_F32) {
+                            cur  = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
+                        }
+
+                        if (node->src1->type == GGML_TYPE_F16) {
+                            cur  = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
+                        }
+
                         work_size = MAX(work_size, cur);
                     } break;
                 case GGML_OP_NONE:
diff --git a/ggml.h b/ggml.h
index 1078fbe82d2ee8984559452a0a479743ba555a88..465a9b6d165ede8b5a26f12efb06dec448e45815 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -12,6 +12,7 @@ extern "C" {
 #define GGML_MAX_NODES    4096
 #define GGML_MAX_PARAMS   16
 #define GGML_MAX_CONTEXTS 16
+#define GGML_MAX_OPT      4
 
 #ifdef __ARM_NEON
 // we use the built-in 16-bit float type
@@ -71,6 +72,9 @@ enum ggml_op {
     GGML_OP_CONV_1D_1S,
     GGML_OP_CONV_1D_2S,
 
+    GGML_OP_FLASH_ATTN,
+    GGML_OP_FLASH_FF,
+
     GGML_OP_COUNT,
 };
 
@@ -93,6 +97,7 @@ struct ggml_tensor {
     struct ggml_tensor * grad;
     struct ggml_tensor * src0;
     struct ggml_tensor * src1;
+    struct ggml_tensor * opt[GGML_MAX_OPT];
 
     // thread scheduling
     int n_tasks;
@@ -182,14 +187,19 @@ struct ggml_tensor * ggml_new_tensor_4d(
         int    ne2,
         int    ne3);
 
+struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
 struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
 
 struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
 struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
 
 struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
+struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
 struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
 
+int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
+void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
+
 float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
 void  ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
 
@@ -399,6 +409,21 @@ struct ggml_tensor * ggml_conv_1d_2s(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b);
 
+struct ggml_tensor * ggml_flash_attn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        bool                  masked);
+
+struct ggml_tensor * ggml_flash_ff(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b0,
+        struct ggml_tensor  * b1,
+        struct ggml_tensor  * c0,
+        struct ggml_tensor  * c1);
+
 //
 // automatic differentiation
 //
index 40835ba72eaebf2e8acce30504de3b39cee673c9..326a8a70d96d7c10d78fa72dff895c5ec808334a 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -1,5 +1,8 @@
 #include "ggml.h"
 
+#define USE_FLASH_ATTN
+#define USE_FLASH_FF
+
 // third-party utilities
 // use your favorite implementations
 #define DR_WAV_IMPLEMENTATION
@@ -16,6 +19,7 @@
 #include <thread>
 #include <vector>
 
+// available whisper models
 enum e_model {
     MODEL_UNKNOWN,
     MODEL_TINY,
@@ -25,14 +29,116 @@ enum e_model {
     MODEL_LARGE,
 };
 
+const std::map<std::string, std::pair<int, std::string>> g_lang = {
+    { "en",  { 0,  "english",         } },
+    { "zh",  { 1,  "chinese",         } },
+    { "de",  { 2,  "german",          } },
+    { "es",  { 3,  "spanish",         } },
+    { "ru",  { 4,  "russian",         } },
+    { "ko",  { 5,  "korean",          } },
+    { "fr",  { 6,  "french",          } },
+    { "ja",  { 7,  "japanese",        } },
+    { "pt",  { 8,  "portuguese",      } },
+    { "tr",  { 9,  "turkish",         } },
+    { "pl",  { 10, "polish",          } },
+    { "ca",  { 11,  "catalan",        } },
+    { "nl",  { 12,  "dutch",          } },
+    { "ar",  { 13,  "arabic",         } },
+    { "sv",  { 14,  "swedish",        } },
+    { "it",  { 15,  "italian",        } },
+    { "id",  { 16,  "indonesian",     } },
+    { "hi",  { 17,  "hindi",          } },
+    { "fi",  { 18,  "finnish",        } },
+    { "vi",  { 19,  "vietnamese",     } },
+    { "iw",  { 20,  "hebrew",         } },
+    { "uk",  { 21,  "ukrainian",      } },
+    { "el",  { 22,  "greek",          } },
+    { "ms",  { 23,  "malay",          } },
+    { "cs",  { 24,  "czech",          } },
+    { "ro",  { 25,  "romanian",       } },
+    { "da",  { 26,  "danish",         } },
+    { "hu",  { 27,  "hungarian",      } },
+    { "ta",  { 28,  "tamil",          } },
+    { "no",  { 29,  "norwegian",      } },
+    { "th",  { 30,  "thai",           } },
+    { "ur",  { 31,  "urdu",           } },
+    { "hr",  { 32,  "croatian",       } },
+    { "bg",  { 33,  "bulgarian",      } },
+    { "lt",  { 34,  "lithuanian",     } },
+    { "la",  { 35,  "latin",          } },
+    { "mi",  { 36,  "maori",          } },
+    { "ml",  { 37,  "malayalam",      } },
+    { "cy",  { 38,  "welsh",          } },
+    { "sk",  { 39,  "slovak",         } },
+    { "te",  { 40,  "telugu",         } },
+    { "fa",  { 41,  "persian",        } },
+    { "lv",  { 42,  "latvian",        } },
+    { "bn",  { 43,  "bengali",        } },
+    { "sr",  { 44,  "serbian",        } },
+    { "az",  { 45,  "azerbaijani",    } },
+    { "sl",  { 46,  "slovenian",      } },
+    { "kn",  { 47,  "kannada",        } },
+    { "et",  { 48,  "estonian",       } },
+    { "mk",  { 49,  "macedonian",     } },
+    { "br",  { 50,  "breton",         } },
+    { "eu",  { 51,  "basque",         } },
+    { "is",  { 52,  "icelandic",      } },
+    { "hy",  { 53,  "armenian",       } },
+    { "ne",  { 54,  "nepali",         } },
+    { "mn",  { 55,  "mongolian",      } },
+    { "bs",  { 56,  "bosnian",        } },
+    { "kk",  { 57,  "kazakh",         } },
+    { "sq",  { 58,  "albanian",       } },
+    { "sw",  { 59,  "swahili",        } },
+    { "gl",  { 60,  "galician",       } },
+    { "mr",  { 61,  "marathi",        } },
+    { "pa",  { 62,  "punjabi",        } },
+    { "si",  { 63,  "sinhala",        } },
+    { "km",  { 64,  "khmer",          } },
+    { "sn",  { 65,  "shona",          } },
+    { "yo",  { 66,  "yoruba",         } },
+    { "so",  { 67,  "somali",         } },
+    { "af",  { 68,  "afrikaans",      } },
+    { "oc",  { 69,  "occitan",        } },
+    { "ka",  { 70,  "georgian",       } },
+    { "be",  { 71,  "belarusian",     } },
+    { "tg",  { 72,  "tajik",          } },
+    { "sd",  { 73,  "sindhi",         } },
+    { "gu",  { 74,  "gujarati",       } },
+    { "am",  { 75,  "amharic",        } },
+    { "yi",  { 76,  "yiddish",        } },
+    { "lo",  { 77,  "lao",            } },
+    { "uz",  { 78,  "uzbek",          } },
+    { "fo",  { 79,  "faroese",        } },
+    { "ht",  { 80,  "haitian creole", } },
+    { "ps",  { 81,  "pashto",         } },
+    { "tk",  { 82,  "turkmen",        } },
+    { "nn",  { 83,  "nynorsk",        } },
+    { "mt",  { 84,  "maltese",        } },
+    { "sa",  { 85,  "sanskrit",       } },
+    { "lb",  { 86,  "luxembourgish",  } },
+    { "my",  { 87,  "myanmar",        } },
+    { "bo",  { 88,  "tibetan",        } },
+    { "tl",  { 89,  "tagalog",        } },
+    { "mg",  { 90,  "malagasy",       } },
+    { "as",  { 91,  "assamese",       } },
+    { "tt",  { 92,  "tatar",          } },
+    { "haw", { 93,  "hawaiian",       } },
+    { "ln",  { 94,  "lingala",        } },
+    { "ha",  { 95,  "hausa",          } },
+    { "ba",  { 96,  "bashkir",        } },
+    { "jw",  { 97,  "javanese",       } },
+    { "su",  { 98,  "sundanese",      } },
+};
+
 const size_t MB = 1024*1024;
 
 const std::map<e_model, size_t> MEM_REQ_MODEL = {
-    { MODEL_TINY,    100ull*MB },
-    { MODEL_BASE,    190ull*MB },
-    { MODEL_SMALL,   610ull*MB },
-    { MODEL_MEDIUM, 1900ull*MB },
-    { MODEL_LARGE,  3600ull*MB },
+    { MODEL_TINY,     86ull*MB },
+    { MODEL_BASE,    165ull*MB },
+    { MODEL_SMALL,   540ull*MB },
+    { MODEL_MEDIUM, 1650ull*MB },
+    { MODEL_LARGE,  3260ull*MB },
 };
 
 const std::map<e_model, size_t> MEM_REQ_ENCODE = {
@@ -44,11 +150,11 @@ const std::map<e_model, size_t> MEM_REQ_ENCODE = {
 };
 
 const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
-    { MODEL_TINY,    170ull*MB },
-    { MODEL_BASE,    230ull*MB },
-    { MODEL_SMALL,   350ull*MB },
-    { MODEL_MEDIUM,  450ull*MB },
-    { MODEL_LARGE,   570ull*MB },
+    { MODEL_TINY,     64ull*MB },
+    { MODEL_BASE,     84ull*MB },
+    { MODEL_SMALL,   128ull*MB },
+    { MODEL_MEDIUM,  172ull*MB },
+    { MODEL_LARGE,   216ull*MB },
 };
 
 const std::map<e_model, size_t> MEM_REQ_DECODE = {
@@ -102,6 +208,10 @@ struct whisper_vocab {
     id token_solm = 50361; // ??
     id token_beg  = 50363;
 
+    // available tasks
+    const id token_translate  = 50358;
+    const id token_transcribe = 50359;
+
     bool is_multilingual() const {
         return n_vocab == 51865;
     }
@@ -109,16 +219,18 @@ struct whisper_vocab {
 
 // command-line parameters
 struct whisper_params {
-    int32_t seed      = -1; // RNG seed
+    int32_t seed      = -1; // RNG seed, not used currently
     int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
 
+    // sampling parameter - used for the greedy strategy
     int32_t max_tokens_per_iter = 64;
 
-    bool verbose = false;
+    bool verbose              = false;
+    bool translate            = false;
     bool print_special_tokens = false;
 
-    std::string model = "models/ggml-base.en.bin"; // model path
-
+    std::string language  = "en";
+    std::string model     = "models/ggml-base.en.bin";
     std::string fname_inp = "samples/jfk.wav";
 };
 
@@ -136,6 +248,15 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.max_tokens_per_iter = std::stoi(argv[++i]);
         } else if (arg == "-v" || arg == "--verbose") {
             params.verbose = true;
+        } else if (arg == "--translate") {
+            params.translate = true;
+        } else if (arg == "-l" || arg == "--language") {
+            params.language = argv[++i];
+            if (g_lang.find(params.language) == g_lang.end()) {
+                fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
+                whisper_print_usage(argc, argv, params);
+                exit(0);
+            }
         } else if (arg == "-ps" || arg == "--print_special") {
             params.print_special_tokens = true;
         } else if (arg == "-m" || arg == "--model") {
@@ -160,16 +281,16 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "usage: %s [options]\n", argv[0]);
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
-    fprintf(stderr, "  -h, --help            show this help message and exit\n");
-    fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n");
-    fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
-    fprintf(stderr, "  -T N, --tokens N      maximum number of tokens to generate per iteration (default: %d)\n", params.max_tokens_per_iter);
-    fprintf(stderr, "  -v, --verbose         verbose output\n");
-    fprintf(stderr, "  -ps, --print_special  print special tokens\n");
-    fprintf(stderr, "  -m FNAME, --model FNAME\n");
-    fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
-    fprintf(stderr, "  -f FNAME, --file FNAME\n");
-    fprintf(stderr, "                        input WAV file path (default: %s)\n", params.fname_inp.c_str());
+    fprintf(stderr, "  -h,       --help           show this help message and exit\n");
+    fprintf(stderr, "  -s SEED,  --seed SEED      RNG seed (default: -1)\n");
+    fprintf(stderr, "  -t N,     --threads N      number of threads to use during computation (default: %d)\n", params.n_threads);
+    fprintf(stderr, "  -T N,     --tokens N       maximum number of tokens to generate per iteration (default: %d)\n", params.max_tokens_per_iter);
+    fprintf(stderr, "  -v,       --verbose        verbose output\n");
+    fprintf(stderr, "            --translate      translate from source language to english\n");
+    fprintf(stderr, "  -ps,      --print_special  print special tokens\n");
+    fprintf(stderr, "  -l LANG,  --language LANG  spoken language (default: %s)\n", params.language.c_str());
+    fprintf(stderr, "  -m FNAME, --model FNAME    model path (default: %s)\n", params.model.c_str());
+    fprintf(stderr, "  -f FNAME, --file FNAME     input WAV file path (default: %s)\n", params.fname_inp.c_str());
     fprintf(stderr, "\n");
 }
 
@@ -417,6 +538,7 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
         printf("%s: f16           = %d\n", __func__, hparams.f16);
         printf("%s: type          = %d\n", __func__, model.type);
 
+        // this is the total memory required to run the inference
         const size_t mem_required =
                    MEM_REQ_MODEL.at(model.type) +
                   MEM_REQ_ENCODE.at(model.type) +
@@ -609,11 +731,11 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
             ctx_size += n_text_layer*(             n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
         }
 
-        ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_k
-        ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_v
+        ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
+        ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
 
-        ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_cross_k
-        ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_cross_v
+        ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
+        ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
 
         ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
 
@@ -836,22 +958,24 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
         const int n_text_layer = hparams.n_text_layer;
         const int n_text_ctx   = hparams.n_text_ctx;
 
+        // key/value memory for the self-attention layer
         {
             const int n_mem      = n_text_layer*n_text_ctx;
             const int n_elements = n_text_state*n_mem;
 
-            model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
-            model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+            model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
         }
 
+        // key/value memory for the cross-attention layer
         {
             const int n_audio_ctx   = hparams.n_audio_ctx;
 
             const int n_mem      = n_text_layer*n_audio_ctx;
             const int n_elements = n_text_state*n_mem;
 
-            model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
-            model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+            model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+            model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
         }
 
         const size_t memory_size =
@@ -1057,14 +1181,14 @@ bool whisper_encode(
                         Qcur),
                     Qcur);
 
-            Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
 
-            // no bias for Key
+            // note: no bias for Key
             struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
                     layer.attn_k_w,
                     cur);
 
-            Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
 
             struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
                     layer.attn_v_w,
@@ -1078,49 +1202,57 @@ bool whisper_encode(
 
             // ------
 
+#ifdef USE_FLASH_ATTN
             struct ggml_tensor * Q =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Qcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
                             Kcur,
-                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), // F16 !
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
                         0, 2, 1, 3);
 
-            //// BLAS attempt
-            //struct ggml_tensor * KQ =
-            //    ggml_mul_mat(ctxL,
-            //        ggml_cpy(ctxL, K, ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, N, n_head)),
-            //        ggml_cpy(ctxL, Q, ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, N, n_head)));
+            struct ggml_tensor * V =
+                ggml_cpy(ctxL,
+                        ggml_permute(ctxL,
+                            ggml_reshape_3d(ctxL,
+                                Vcur,
+                                n_state/n_head, n_head, N),
+                            1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
+                        );
 
-            // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
+            struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
+#else
+            struct ggml_tensor * Q =
+                ggml_permute(ctxL,
+                        ggml_cpy(ctxL,
+                            Qcur,
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+                        0, 2, 1, 3);
 
-            //struct ggml_tensor * K =
-            //    ggml_cpy(ctxL,
-            //            ggml_permute(ctxL,
-            //                ggml_reshape_3d(ctxL,
-            //                    Kcur,
-            //                    n_state/n_head, n_head, N),
-            //                1, 2, 0, 3),
-            //            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
-            //            );
+            struct ggml_tensor * K =
+                ggml_permute(ctxL,
+                        ggml_cpy(ctxL,
+                            Kcur,
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+                        0, 2, 1, 3);
 
-            //// K * Q
-            //struct ggml_tensor * KQ = ggml_mul_mat(ctxL, ggml_transpose(ctxL, K), Q);
+            // K * Q
+            struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
 
-            //struct ggml_tensor * KQ_scaled =
-            //    ggml_scale(ctxL,
-            //            KQ,
-            //            ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
-            //            );
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctxL,
+                        KQ,
+                        ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+                        );
 
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ);
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled);
 
             //struct ggml_tensor * V_trans =
             //    ggml_permute(ctxL,
@@ -1138,10 +1270,11 @@ bool whisper_encode(
                                 Vcur,
                                 n_state/n_head, n_head, N),
                             0, 2, 1, 3),
-                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head) // F16 !
+                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head)
                         );
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
+#endif
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
 
@@ -1180,6 +1313,11 @@ bool whisper_encode(
                         ggml_repeat(ctxL, layer.mlp_ln_b, cur));
             }
 
+#ifdef USE_FLASH_FF
+            cur = ggml_flash_ff(ctxL,
+                    ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
+                    layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
+#else
             // fully connected
             cur = ggml_mul_mat(ctxL,
                     layer.mlp_0_w,
@@ -1200,6 +1338,7 @@ bool whisper_encode(
             cur = ggml_add(ctxL,
                     ggml_repeat(ctxL, layer.mlp_1_b, cur),
                     cur);
+#endif
         }
 
         // output from this layer
@@ -1368,7 +1507,7 @@ bool whisper_decode(
         ((int32_t *) position->data)[i] = n_past + i;
     }
 
-    // wte + wpe
+    // token encoding + position encoding
     struct ggml_tensor * cur =
         ggml_add(ctx0,
                 ggml_get_rows(ctx0, model.d_te, embd),
@@ -1420,7 +1559,7 @@ bool whisper_decode(
 
             Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
 
-            // no bias for Key
+            // note: no bias for Key
             struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
                     layer.attn_k_w,
                     cur);
@@ -1506,7 +1645,7 @@ bool whisper_decode(
 
         // norm
         {
-            cur = ggml_norm(ctxL, inpCA); // Note we use inpCA here
+            cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here
 
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctxL,
@@ -1589,7 +1728,6 @@ bool whisper_decode(
                     cur);
         }
 
-
         // add the input
         cur = ggml_add(ctxL, cur, inpCA);
 
@@ -1601,8 +1739,7 @@ bool whisper_decode(
             {
                 cur = ggml_norm(ctxL, inpFF);
 
-                // cur = ln_2_g*cur + ln_2_b
-                // [ 768, N]
+                // cur = mlp_ln_w*cur + mlp_ln_b
                 cur = ggml_add(ctxL,
                         ggml_mul(ctxL,
                             ggml_repeat(ctxL, layer.mlp_ln_w, cur),
@@ -1689,11 +1826,11 @@ bool whisper_decode(
     probs_out.resize(N*n_vocab);
     memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
 
-    //if (N > 1) {
-    //    const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
-    //    printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
-    //    printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
-    //}
+    if (N > 1) {
+        //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
+        //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
+        //printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
+    }
 
     ggml_free(ctx0);
 
@@ -1981,8 +2118,36 @@ int main(int argc, char ** argv) {
         t_mel_us = ggml_time_us() - t_start_us;
     }
 
+    // print some info about the processing
+    {
+        printf("\n");
+        if (!vocab.is_multilingual()) {
+            if (params.language != "en" || params.translate) {
+                params.language = "en";
+                params.translate = false;
+                printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
+            }
+        }
+        printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s ...\n",
+                __func__, int(pcmf32.size()), float(pcmf32.size())/SAMPLE_RATE, params.n_threads,
+                g_lang.at(params.language).second.c_str(),
+                params.translate ? "translate" : "transcribe");
+    }
+
+    // the accumulated text context so far
     std::vector<whisper_vocab::id> prompt_past = { };
 
+    // these tokens determine the task that will be performed
+    std::vector<whisper_vocab::id> prompt_init = { vocab.token_sot };
+    if (vocab.is_multilingual()) {
+        prompt_init.push_back(vocab.token_sot + 1 + g_lang.at(params.language).first);
+        if (params.translate) {
+            prompt_init.push_back(vocab.token_translate);
+        } else {
+            prompt_init.push_back(vocab.token_transcribe);
+        }
+    }
+
     // main loop
     int seek = 0;
     while (true) {
@@ -2006,24 +2171,23 @@ int main(int argc, char ** argv) {
         std::vector<float> probs;
         std::vector<float> logits;
 
-        // SOT
-        // ref: https://github.com/openai/whisper/blob/15ab54826343c27cfaf44ce31e9c8fb63d0aa775/whisper/decoding.py#L506-L526
-        // TODO: use different initial tokens for different tasks
-        std::vector<whisper_vocab::id> prompt = { vocab.token_sot };
+        std::vector<whisper_vocab::id> prompt;
 
         int n_past = 0;
 
+        // if we have already generated some text, use it as a prompt to condition the next generation
         if (prompt_past.size() > 0) {
             int n_take = std::min(model.hparams.n_text_ctx/2, int(prompt_past.size()));
 
             prompt = { vocab.token_prev };
-            prompt.insert(prompt.end(), prompt_past.end() - n_take, prompt_past.end());
-            prompt.push_back(vocab.token_sot);
+            prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
 
             prompt_past.clear();
-            prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - 1);
+            prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
         }
 
+        prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+
         bool done = false;
         int seek_delta = 100*CHUNK_SIZE;
         whisper_vocab::id last_id = 0;
@@ -2049,6 +2213,16 @@ int main(int argc, char ** argv) {
             n_past += prompt.size();
             prompt.clear();
 
+            // very basic greedy sampling strategy:
+            //
+            //   - always take the most probable token
+            //   - if we have accumulated more than 'params.max_tokens_per_iter' -> pick most probable timestamp token
+            //     and advance the sliding window by that amount
+            //   - in the meantime, if we encounter 2 consecutive timestamp tokens, we advance the sliding window too
+            //
+            // more sophisticated sampling strategies could be implemented here, but we keep it simple
+            // feel free to experiment!
+            //
             {
                 // sample next token
                 const float temp  = 1.0; // TODO