]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
whisper : token-level timestamps with DTW (whisper/1485)
authordenersc <redacted>
Wed, 20 Mar 2024 16:25:26 +0000 (13:25 -0300)
committerGeorgi Gerganov <redacted>
Wed, 27 Mar 2024 11:35:37 +0000 (13:35 +0200)
* whisper.cpp: impl dtw algo

* WIP: producing and placing DTW timestamps on tokens

* Fix compile and assertion errors. Attempt to DTW timestamp with single_segment=false.

* Fix mistake causing incorrect alignment of dtw timestamps

* implement N_TOP_MOST and CUSTOM alignment heads setting

* whisper: fix typo on alignment heads enum

* Fix issues related to changes in whisper.cpp

* Fixed excessive memory use when using DTW timestamps. Other minor fixes to DTW timestamping function

* decoder: save cross QKs only if requested

* Calling median filter with ggml_map_custom1

* Reimpl aheads n_top_most and custom. Sanity checks on chosen aheads

* Copying cross QKs from decoder backend correctly

* dtw: cleanup

* Fix incorrect n_frames passed to dtw when near end of audio

* Fix aheads_masks_init for backend != CPU

* whisper : minor style

* main : add dtw (wip)

* whisper: fix invalid memory access in aheads_masks_init

* main : add dtw (cont)

* whisper : minor

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/whisper/main.cpp
examples/whisper/whisper.cpp
examples/whisper/whisper.h

index 50185cc0a854dc6418d010798613ddc67459d99a..caa800b3df8ba45445809a2f7873099866e509ca 100644 (file)
@@ -26,17 +26,17 @@ void replace_all(std::string & s, const std::string & search, const std::string
 
 // command-line parameters
 struct whisper_params {
-    int32_t n_threads    = std::min(4, (int32_t) std::thread::hardware_concurrency());
-    int32_t n_processors  1;
-    int32_t offset_t_ms   0;
-    int32_t offset_n      0;
-    int32_t duration_ms   0;
-    int32_t progress_step =  5;
-    int32_t max_context  = -1;
-    int32_t max_len       0;
-    int32_t best_of      = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
-    int32_t beam_size    = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
-    int32_t audio_ctx   = 0;
+    int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency());
+    int32_t n_processors  = 1;
+    int32_t offset_t_ms   = 0;
+    int32_t offset_n      = 0;
+    int32_t duration_ms   = 0;
+    int32_t progress_step = 5;
+    int32_t max_context   = -1;
+    int32_t max_len       = 0;
+    int32_t best_of       = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
+    int32_t beam_size     = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
+    int32_t audio_ctx     = 0;
 
     float word_thold    =  0.01f;
     float entropy_thold =  2.40f;
@@ -76,6 +76,8 @@ struct whisper_params {
 
     std::string openvino_encode_device = "CPU";
 
+    std::string dtw = "";
+
     std::vector<std::string> fname_inp = {};
     std::vector<std::string> fname_out = {};
 };
@@ -149,6 +151,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-m"    || arg == "--model")           { params.model           = argv[++i]; }
         else if (arg == "-f"    || arg == "--file")            { params.fname_inp.emplace_back(argv[++i]); }
         else if (arg == "-oved" || arg == "--ov-e-device")     { params.openvino_encode_device = argv[++i]; }
+        else if (arg == "-dtw"  || arg == "--dtw")             { params.dtw             = argv[++i]; }
         else if (arg == "-ls"   || arg == "--log-score")       { params.log_score       = true; }
         else if (arg == "-ng"   || arg == "--no-gpu")          { params.use_gpu         = false; }
         else {
@@ -208,6 +211,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -m FNAME,  --model FNAME       [%-7s] model path\n",                                     params.model.c_str());
     fprintf(stderr, "  -f FNAME,  --file FNAME        [%-7s] input WAV file path\n",                            "");
     fprintf(stderr, "  -oved D,   --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n",  params.openvino_encode_device.c_str());
+    fprintf(stderr, "  -dtw MODEL --dtw MODEL         [%-7s] compute token-level timestamps\n",                 params.dtw.c_str());
     fprintf(stderr, "  -ls,       --log-score         [%-7s] log best decoder scores of tokens\n",              params.log_score?"true":"false");
     fprintf(stderr, "  -ng,       --no-gpu            [%-7s] disable GPU\n",                                    params.use_gpu ? "false" : "true");
     fprintf(stderr, "\n");
@@ -649,7 +653,8 @@ bool output_json(
                                     times_o(token.t0, token.t1, false);
                                 }
                                 value_i("id", token.id, false);
-                                value_f("p", token.p, true);
+                                value_f("p", token.p, false);
+                                value_f("t_dtw", token.t_dtw, true);
                             end_obj(j == (n - 1));
                         }
                         end_arr(!params.diarize && !params.tinydiarize);
@@ -889,6 +894,28 @@ int main(int argc, char ** argv) {
     struct whisper_context_params cparams = whisper_context_default_params();
     cparams.use_gpu = params.use_gpu;
 
+    if (!params.dtw.empty()) {
+        cparams.dtw_token_timestamps = true;
+        cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
+
+        if (params.dtw == "tiny")      cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY;
+        if (params.dtw == "tiny.en")   cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN;
+        if (params.dtw == "base")      cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE;
+        if (params.dtw == "base.en")   cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN;
+        if (params.dtw == "small")     cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL;
+        if (params.dtw == "small.en")  cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN;
+        if (params.dtw == "medium")    cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM;
+        if (params.dtw == "medium.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN;
+        if (params.dtw == "large.v1")  cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1;
+        if (params.dtw == "large.v2")  cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2;
+        if (params.dtw == "large.v3")  cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3;
+
+        if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
+            fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str());
+            return 3;
+        }
+    }
+
     struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
 
     if (ctx == nullptr) {
index fb2f5eb393b4906914dde4a88fcbe122d9007cb1..a43e50daa3bfaa2db55dd6c9363a2a7151ebc2ae 100644 (file)
@@ -351,6 +351,35 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
     { "yue", { 99,  "cantonese",      } },
 };
 
+// [EXPERIMENTAL] Token-level timestamps with DTW
+static const whisper_ahead g_aheads_tiny_en[]   = { {1, 0}, {2, 0}, {2, 5}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4} };
+static const whisper_ahead g_aheads_tiny[]      = { {2, 2}, {3, 0}, {3, 2}, {3, 3}, {3, 4}, {3, 5} };
+static const whisper_ahead g_aheads_base_en[]   = { {3, 3}, {4, 7}, {5, 1}, {5, 5}, {5, 7} };
+static const whisper_ahead g_aheads_base[]      = { {3, 1}, {4, 2}, {4, 3}, {4, 7}, {5, 1}, {5, 2}, {5, 4}, {5, 6} };
+static const whisper_ahead g_aheads_small_en[]  = { {6, 6}, {7, 0}, {7, 3}, {7, 8}, {8, 2}, {8, 5}, {8, 7}, {9, 0}, {9, 4}, {9, 8}, {9, 10}, {10, 0}, {10, 1}, {10, 2}, {10, 3}, {10, 6}, {10, 11}, {11, 2}, {11, 4} };
+static const whisper_ahead g_aheads_small[]     = { {5, 3}, {5, 9}, {8, 0}, {8, 4}, {8, 7}, {8, 8}, {9, 0}, {9, 7}, {9, 9}, {10, 5} };
+static const whisper_ahead g_aheads_medium_en[] = { {11, 4}, {14, 1}, {14, 12}, {14, 14}, {15, 4}, {16, 0}, {16, 4}, {16, 9}, {17, 12}, {17, 14}, {18, 7}, {18, 10}, {18, 15}, {20, 0}, {20, 3}, {20, 9}, {20, 14}, {21, 12} };
+static const whisper_ahead g_aheads_medium[]    = { {13, 15}, {15, 4}, {15, 15}, {16, 1}, {20, 0}, {23, 4} };
+static const whisper_ahead g_aheads_large_v1[]  = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} };
+static const whisper_ahead g_aheads_large_v2[]  = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} };
+static const whisper_ahead g_aheads_large_v3[]  = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} };
+
+static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
+    { WHISPER_AHEADS_TINY_EN,   {  8, g_aheads_tiny_en   } },
+    { WHISPER_AHEADS_TINY,      {  6, g_aheads_tiny      } },
+    { WHISPER_AHEADS_BASE_EN,   {  5, g_aheads_base_en   } },
+    { WHISPER_AHEADS_BASE,      {  8, g_aheads_base      } },
+    { WHISPER_AHEADS_SMALL_EN,  { 19, g_aheads_small_en  } },
+    { WHISPER_AHEADS_SMALL,     { 10, g_aheads_small     } },
+    { WHISPER_AHEADS_MEDIUM_EN, { 18, g_aheads_medium_en } },
+    { WHISPER_AHEADS_MEDIUM,    {  6, g_aheads_medium    } },
+    { WHISPER_AHEADS_LARGE_V1,  {  9, g_aheads_large_v1  } },
+    { WHISPER_AHEADS_LARGE_V2,  { 23, g_aheads_large_v2  } },
+    { WHISPER_AHEADS_LARGE_V3,  { 10, g_aheads_large_v3  } },
+};
+
+static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
+
 struct whisper_mel {
     int n_len;
     int n_len_org;
@@ -750,6 +779,13 @@ struct whisper_decoder {
     mutable std::mt19937 rng; // used for sampling at t > 0.0
 };
 
+// [EXPERIMENTAL] Token-level timestamps with DTW
+struct whisper_aheads_masks {
+    std::vector<struct ggml_tensor *> m;    // One mask per text layer.
+    struct ggml_context * ctx = nullptr;
+    ggml_backend_buffer_t buffer = nullptr;
+};
+
 struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
@@ -823,6 +859,11 @@ struct whisper_state {
 
     std::vector<float> energy; // PCM signal energy
 
+    // [EXPERIMENTAL] Token-level timestamps with DTW
+    whisper_aheads_masks aheads_masks;
+    ggml_tensor * aheads_cross_QKs = nullptr;
+    std::vector<float> aheads_cross_QKs_data;
+
     // [EXPERIMENTAL] speed-up techniques
     int32_t exp_n_audio_ctx = 0; // 0 - use default
 };
@@ -1027,6 +1068,132 @@ static void whisper_kv_cache_seq_cp(
     }
 }
 
+// [EXPERIMENTAL] Token-level timestamps with DTW
+static bool aheads_masks_init(
+        const whisper_context_params & cparams,
+               const whisper_hparams & hparams,
+         struct whisper_aheads_masks & aheads_masks,
+                      ggml_backend_t   backend) {
+
+    const int32_t n_text_layer = hparams.n_text_layer;
+    const int32_t n_head = hparams.n_text_head;
+
+    // Sanity checks
+    if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
+        WHISPER_LOG_ERROR("%s: dtw_aheads_preset should be != DTW_AHEADS_NONE\n", __func__);
+        return false;
+    } else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
+        if (cparams.dtw_n_top > n_text_layer || cparams.dtw_n_top <= 0) {
+            WHISPER_LOG_ERROR("%s: dtw_n_top must be between %d and %d for this model.", __func__, 1, n_text_layer);
+            return false;
+        }
+    } else {
+        const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
+        if (cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM) {
+            if (aheads.n_heads == 0) {
+                WHISPER_LOG_ERROR("%s: dtw_aheads.n_heads should be > 0", __func__);
+                return false;
+            }
+            if (aheads.heads == NULL) {
+                WHISPER_LOG_ERROR("%s: dtw_aheads.heads unset", __func__);
+                return false;
+            }
+        }
+        for (size_t i = 0; i < aheads.n_heads; ++i) {
+            if (aheads.heads[i].n_text_layer >= n_text_layer) {
+                WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer %d, but model only has %d text layers", __func__, aheads.heads[i].n_text_layer + 1, n_text_layer);
+                return false;
+            }
+            if (aheads.heads[i].n_text_layer < 0) {
+                WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer < 0", __func__);
+                return false;
+            }
+            if (aheads.heads[i].n_head >= n_head) {
+                WHISPER_LOG_ERROR("%s: tried to set alignment head on head %d, but model only has %d heads", __func__, aheads.heads[i].n_head + 1, n_head);
+                return false;
+            }
+            if (aheads.heads[i].n_head < 0) {
+                WHISPER_LOG_ERROR("%s: tried to set alignment head on head < 0", __func__);
+                return false;
+            }
+        }
+    }
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ (size_t) static_cast<size_t>(n_text_layer)*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
+    };
+
+    aheads_masks.ctx = ggml_init(params);
+
+    if (!aheads_masks.ctx) {
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for the aheads_masks context\n", __func__);
+        return false;
+    }
+
+    for (int64_t il = 0; il < n_text_layer; ++il) {
+        auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
+        if (!aheads.empty()) {
+            aheads_masks.m.push_back(ggml_new_tensor_2d(aheads_masks.ctx, GGML_TYPE_F32, n_head, aheads.size()));
+        } else {
+            aheads_masks.m.push_back(nullptr);
+        }
+    }
+
+    aheads_masks.buffer = ggml_backend_alloc_ctx_tensors(aheads_masks.ctx, backend);
+    if (!aheads_masks.buffer) {
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for aheads_masks\n", __func__);
+        return false;
+    }
+
+    // Set data on mask tensors
+    // Since this must be backend agnostic, we get tensor data with
+    // ggml_backend_tensor_get, copy our desired values and send it back
+    // to backend with ggml_backend_tensor_set
+    std::vector<float> mask_data;
+    for (int64_t il = 0; il < n_text_layer; ++il) {
+        if (aheads_masks.m[il] != nullptr) {
+            auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
+
+            size_t data_size = aheads_masks.m[il]->ne[0] * aheads_masks.m[il]->ne[1] * sizeof(float);
+            mask_data.resize(data_size);
+            ggml_backend_tensor_get(aheads_masks.m[il], mask_data.data(), 0, data_size);
+            memset(mask_data.data(), 0, data_size);
+
+            for (size_t ih = 0; ih < aheads.size(); ++ih) {
+                size_t pos = (aheads[ih] + (ih * aheads_masks.m[il]->ne[0] * aheads[ih]));
+                float v = 1.0f;
+                memcpy(mask_data.data() + pos, &v, sizeof(float));
+            }
+
+            ggml_backend_tensor_set(aheads_masks.m[il], mask_data.data(), 0, data_size);
+        }
+    }
+
+    if (aheads_masks.m.empty()) {
+        WHISPER_LOG_ERROR("%s: \n", __func__);
+        return false;
+    }
+
+    return true;
+}
+
+static void aheads_masks_free(struct whisper_aheads_masks & aheads_masks) {
+    ggml_free(aheads_masks.ctx);
+    ggml_backend_buffer_free(aheads_masks.buffer);
+    aheads_masks.ctx = nullptr;
+}
+
+static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
+    size_t size = 0;
+    for (size_t i = 0; i < aheads_masks.m.size(); ++i) {
+        if (aheads_masks.m[i] != nullptr)
+            size += ggml_nbytes(aheads_masks.m[i]);
+    }
+    return size;
+}
+
 static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
     ggml_backend_t backend_gpu = NULL;
 
@@ -2105,6 +2272,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
          whisper_context & wctx,
          whisper_state   & wstate,
      const whisper_batch & batch,
+                    bool   save_alignment_heads_QKs,
                     bool   worst_case) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
@@ -2158,6 +2326,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     struct ggml_tensor * inpL = cur;
 
+    // [EXPERIMENTAL] Token-level timestamps with DTW
+    struct ggml_tensor * aheads_cross_QKs = nullptr;
+
     for (int il = 0; il < n_layer; ++il) {
         const auto & layer = model.layers_decoder[il];
 
@@ -2337,6 +2508,24 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
 
+            // [EXPERIMENTAL] Token-level timestamps with DTW
+            if (wctx.params.dtw_token_timestamps) {
+                if (wstate.aheads_masks.m[il] != nullptr) {
+                    struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
+                    aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+                    aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+                    aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
+                    aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+                    aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+                    aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
+                    if (aheads_cross_QKs == NULL) {
+                        aheads_cross_QKs = aheads_KQs;
+                    } else {
+                        aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs);
+                    }
+                }
+            }
+
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@@ -2422,6 +2611,16 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
+    // [EXPERIMENTAL] Token-level timestamps with DTW
+    if (wctx.params.dtw_token_timestamps && aheads_cross_QKs != nullptr) {
+        aheads_cross_QKs = ggml_transpose(ctx0, aheads_cross_QKs);
+        aheads_cross_QKs = ggml_cont(ctx0, aheads_cross_QKs);
+        if (save_alignment_heads_QKs) {
+            ggml_build_forward_expand(gf, aheads_cross_QKs);
+            wstate.aheads_cross_QKs = aheads_cross_QKs;
+        }
+    }
+
     ggml_build_forward_expand(gf, logits);
 
     ggml_free(ctx0);
@@ -2444,6 +2643,7 @@ static bool whisper_decode_internal(
           whisper_state & wstate,
     const whisper_batch & batch,
               const int   n_threads,
+                   bool   save_alignment_heads_QKs,
     ggml_abort_callback   abort_callback,
                    void * abort_callback_data) {
     const int64_t t_start_us = ggml_time_us();
@@ -2475,7 +2675,7 @@ static bool whisper_decode_internal(
     {
         auto & alloc = wstate.alloc_decode.alloc;
 
-        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, false);
+        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);
 
         if (!ggml_gallocr_alloc_graph(alloc, gf)) {
             // should never happen as we pre-allocate the memory
@@ -3003,6 +3203,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
+    // [EXPERIMENTAL] Token-level timestamps with DTW
+    if (ctx->params.dtw_token_timestamps) {
+        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
+            WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
+            whisper_free_state(state);
+            return nullptr;
+        }
+        const size_t memory_size = aheads_masks_nbytes(state->aheads_masks);
+        WHISPER_LOG_INFO("%s: alignment heads masks size = %ld B\n", __func__, memory_size);
+    }
+
 #ifdef WHISPER_USE_COREML
     const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
 
@@ -3095,7 +3306,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
                     whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
 
-                    return whisper_build_graph_decoder(*ctx, *state, state->batch, true);
+                    return whisper_build_graph_decoder(*ctx, *state, state->batch, ctx->params.dtw_token_timestamps, true);
                 });
 
         if (!ok) {
@@ -3161,8 +3372,17 @@ int whisper_ctx_init_openvino_encoder(
 
 struct whisper_context_params whisper_context_default_params() {
     struct whisper_context_params result = {
-        /*.use_gpu    =*/ true,
-        /*.gpu_device =*/ 0,
+        /*.use_gpu              =*/ true,
+        /*.gpu_device           =*/ 0,
+
+        /*.dtw_token_timestamps =*/ false,
+        /*.dtw_aheads_preset    =*/ WHISPER_AHEADS_NONE,
+        /*.dtw_n_top            =*/ -1,
+        /*.dtw_aheads           =*/ {
+            /*.n_heads          =*/ 0,
+            /*.heads            =*/ NULL,
+        },
+        /*.dtw_mem_size         =*/ 1024*1024*128,
     };
     return result;
 }
@@ -3357,6 +3577,9 @@ void whisper_free_state(struct whisper_state * state) {
 
         ggml_backend_free(state->backend);
 
+        // [EXPERIMENTAL] Token-level timestamps with DTW
+        aheads_masks_free(state->aheads_masks);
+
         delete state;
     }
 }
@@ -3476,7 +3699,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
 
     whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
 
-    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
+    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, false, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -4411,6 +4634,17 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
     return txt[0] == ' ';
 }
 
+static void whisper_exp_compute_token_level_timestamps_dtw(
+            struct whisper_context * ctx,
+              struct whisper_state * state,
+        struct whisper_full_params   params,
+                               int   i_segment,
+                            size_t   n_segments,
+                               int   seek,
+                               int   n_frames,
+                               int   medfilt_width,
+                               int   n_threads);
+
 // wrap the last segment to max_len characters
 // returns the number of new segments
 static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
@@ -4779,7 +5013,7 @@ static whisper_token_data whisper_sample_token(
       const whisper_decoder & decoder,
                        bool   best) {
     whisper_token_data result = {
-        0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
+        0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, 0.0f,
     };
 
     const auto & vocab = ctx.vocab;
@@ -4897,7 +5131,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
         const auto id = dist(decoder.rng);
         //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
 
-        result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
+        result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, 0.0f, });
 
         if (result[i].id >= vocab.token_beg) {
             result[i].tid = result[i].id;
@@ -5259,7 +5493,7 @@ int whisper_full_with_state(
 
                 whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
 
-                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                     WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                     return -7;
                 }
@@ -5559,7 +5793,7 @@ int whisper_full_with_state(
 
                     assert(batch.n_tokens > 0);
 
-                    if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                    if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                         WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                         return -8;
                     }
@@ -5682,6 +5916,9 @@ int whisper_full_with_state(
 
             const auto & tokens_cur = best_decoder.sequence.tokens;
 
+            // [EXPERIMENTAL] Token-level timestamps with DTW
+            const auto n_segments_before = state->result_all.size();
+
             //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
 
             // update prompt_past
@@ -5799,6 +6036,17 @@ int whisper_full_with_state(
                 }
             }
 
+            // FIXME: will timestamp offsets be correct?
+            // [EXPERIMENTAL] Token-level timestamps with DTW
+            {
+                const auto n_segments = state->result_all.size() - n_segments_before;
+                if (ctx->params.dtw_token_timestamps && n_segments) {
+                    const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
+                    whisper_exp_compute_token_level_timestamps_dtw(
+                            ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
+                }
+            }
+
             // update audio window
             seek += seek_delta;
 
@@ -6601,6 +6849,321 @@ static void whisper_exp_compute_token_level_timestamps(
     //}
 }
 
+//
+// token level timestamps - dtw version
+//
+
+// n_text_layer -> total text layers on model
+// n_head -> total heads per text layer on model
+static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int n_text_layer, int n_head) {
+    std::vector<uint32_t> ret;
+    if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
+        return ret;
+    } else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
+        if (il >= n_text_layer - cparams.dtw_n_top) {
+            for (int32_t i = 0; i < n_head; ++i) {
+                ret.push_back(i);
+            }
+        }
+    } else {
+        const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
+        for (size_t i = 0; i < aheads.n_heads; ++i) {
+            if (aheads.heads[i].n_text_layer == il) {
+                ret.push_back(aheads.heads[i].n_head);
+            }
+        }
+    }
+    return ret;
+}
+
+// dtw + backtrace to return found path
+// based on
+// https://github.com/openai/whisper/blob/main/whisper/timing.py#L83
+static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
+    WHISPER_ASSERT(ggml_n_dims(x) == 2);
+
+    int64_t N = x->ne[0];
+    int64_t M = x->ne[1];
+    struct ggml_tensor * cost = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N + 1, M + 1);
+    struct ggml_tensor * trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1);
+
+    cost = ggml_set_f32(cost, INFINITY);
+    trace = ggml_set_f32(trace, -1);
+    ggml_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
+
+    // dtw
+    // supposedly can be optmized by computing diagonals in parallel ?
+    // Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
+    for (int64_t j = 1; j < M + 1; ++j) {
+        for (int64_t i = 1; i < N + 1; ++i) {
+            float c0 = ggml_get_f32_nd(cost, i - 1, j - 1, 0, 0);
+            float c1 = ggml_get_f32_nd(cost, i - 1, j, 0, 0);
+            float c2 = ggml_get_f32_nd(cost, i, j - 1, 0, 0);
+
+            float c;
+            int32_t t;
+            if (c0 < c1 && c0 < c2) {
+                c = c0;
+                t = 0;
+            } else if (c1 < c0 && c1 < c2) {
+                c = c1;
+                t = 1;
+            } else {
+                c = c2;
+                t = 2;
+            }
+
+            c = ggml_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
+            ggml_set_f32_nd(cost, i, j, 0, 0, c);
+            ggml_set_i32_nd(trace, i, j, 0, 0, t);
+        }
+    }
+
+    // Backtrace
+    const int64_t BT_MAX_ROWS = N + M - 1;
+    struct ggml_tensor * bt = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, BT_MAX_ROWS, 2);
+    // trace[0, :] = 2;
+    for (int64_t i = 0; i < M + 1; ++i)
+        ggml_set_i32_nd(trace, 0, i, 0, 0, 2);
+    //trace[:, 0] = 1;
+    for (int64_t i = 0; i < N + 1; ++i)
+        ggml_set_i32_nd(trace, i, 0, 0, 0, 1);
+    int bt_row_idx = BT_MAX_ROWS - 1;
+    int64_t i = N;
+    int64_t j = M;
+    while (i > 0 || j > 0) {
+        ggml_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
+        ggml_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
+        --bt_row_idx;
+
+        int32_t t = ggml_get_i32_nd(trace, i, j, 0, 0);
+        if (t == 0) {
+            --i;
+            --j;
+        } else if (t == 1) {
+            --i;
+        } else if (t == 2) {
+            --j;
+        } else {
+            WHISPER_ASSERT(0);
+        }
+    }
+
+    // FIXME: manual clip/transpose might not be the most efficient way? (e.g. use ggml funcs)
+    // Clip + transpose
+    // This might not be entirely necessary for our case, but leaving it for now so output matrix
+    // is identical to dtw on openAI timing.py
+    const int64_t result_n_cols = BT_MAX_ROWS-bt_row_idx-1;
+    ggml_tensor * r = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, result_n_cols);
+    for (int64_t i = 0; i < 2; ++i) {
+        for (int64_t j = 0; j < result_n_cols; ++j) {
+            int32_t v = ggml_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
+            ggml_set_i32_nd(r, i, j, 0, 0, v);
+        }
+    }
+
+    return r;
+}
+
+struct median_filter_user_data {
+    int filter_width;
+};
+
+static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata) {
+    int filter_width = ((median_filter_user_data *) userdata)->filter_width;
+    WHISPER_ASSERT(nth == 1);
+    WHISPER_ASSERT(ith == 0);
+    WHISPER_ASSERT(filter_width < a->ne[2]);
+    WHISPER_ASSERT(filter_width % 2);
+    WHISPER_ASSERT(ggml_n_dims(a) == 3);
+    WHISPER_ASSERT(a->type == GGML_TYPE_F32);
+
+    std::vector<float> filter;
+    filter.reserve(filter_width);
+    for (int64_t i = 0; i < a->ne[0]; ++i) {
+        for (int64_t j = 0; j < a->ne[1]; ++j) {
+            for (int64_t k = 0; k < a->ne[2]; ++k) {
+                for (int64_t off = -filter_width/2; off <= filter_width/2; ++off) {
+                    // "reflect" padding
+                    int64_t idx = k + off;
+                    if (idx < 0) {
+                        idx = -idx;
+                    } else if (idx >= a->ne[2]) {
+                        idx = 2*(a->ne[2] - 1) - idx;
+                    }
+
+                    filter.push_back(ggml_get_f32_nd(a, i, j, idx, 0));
+                }
+                std::sort(filter.begin(), filter.end());
+                const float v = filter[filter.size()/2];
+                ggml_set_f32_nd(dst, i, j, k, 0, v);
+                filter.clear();
+            }
+        }
+    }
+}
+
+static void whisper_exp_compute_token_level_timestamps_dtw(
+            struct whisper_context * ctx,
+              struct whisper_state * state,
+        struct whisper_full_params   params,
+                               int   i_segment,
+                            size_t   n_segments,
+                               int   seek,
+                               int   n_frames,
+                               int   medfilt_width,
+                               int   n_threads)
+{
+    const int n_audio_ctx = state->exp_n_audio_ctx > 0 ? state->exp_n_audio_ctx : ctx->model.hparams.n_audio_ctx;
+    WHISPER_ASSERT(medfilt_width % 2);
+    WHISPER_ASSERT(n_frames <= n_audio_ctx * 2);
+    WHISPER_ASSERT(ctx->params.dtw_aheads_preset != WHISPER_AHEADS_NONE);
+
+    // FIXME: Allocating mem everytime we call this func
+    // Our ggml buffer should be pre-allocated somewhere during init and reused
+    // when we call this function
+    struct ggml_init_params gparams = {
+        /*.mem_size   =*/ ctx->params.dtw_mem_size,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ false,
+    };
+    struct ggml_context * gctx = ggml_init(gparams);
+
+    // Build token sequence that will be passed to decoder
+    // sot + [lang] + text result + eot
+    std::vector<whisper_token> tokens = { whisper_token_sot(ctx), };
+    if (whisper_is_multilingual(ctx)) {
+        const int lang_id = whisper_lang_id(params.language);
+        state->lang_id = lang_id;
+        tokens.push_back(whisper_token_lang(ctx, lang_id));
+    }
+    const size_t sot_sequence_length = tokens.size();
+    tokens.push_back(whisper_token_not(ctx));
+    for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
+        auto & segment = state->result_all[i];
+        for (auto &t: segment.tokens) {
+            // Only text tokens
+            if (t.id < whisper_token_eot(ctx)) {
+                tokens.push_back(t.id);
+            }
+        }
+    }
+    tokens.push_back(whisper_token_eot(ctx));
+
+    // Get result tokens, pass then along to decoder to get cross attention QKs
+    // used in timestamping
+    // Decoder already returns only alignment head QKs, already concatenated in
+    // one tensor.
+    whisper_kv_cache_clear(state->kv_self);
+    whisper_batch_prep_legacy(state->batch, tokens.data(), tokens.size(), 0, 0);
+    whisper_kv_cache_seq_rm(state->kv_self, 0, 0, -1);
+    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr)) {
+        WHISPER_LOG_INFO("DECODER FAILED\n");
+        WHISPER_ASSERT(0);
+    }
+    WHISPER_ASSERT(state->aheads_cross_QKs != nullptr);
+
+    const auto n_audio_tokens = n_frames/2;
+    WHISPER_ASSERT(state->aheads_cross_QKs != NULL);
+    WHISPER_ASSERT(n_audio_tokens <= state->aheads_cross_QKs->ne[1]);
+    const auto n_tokens = state->aheads_cross_QKs->ne[0];
+    const auto n_heads = state->aheads_cross_QKs->ne[2];
+
+    // Copy data from decoder buffer to a local CPU tensor, discarding unused audio
+    // tokens (i.e. discarding rows at the end of tensor)
+    // IN: Tensor with N_TOKENS*audio_ctx*N_ALIGNMENT_HEADS dims
+    // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
+    WHISPER_ASSERT(state->aheads_cross_QKs->type == GGML_TYPE_F32);
+    WHISPER_ASSERT(ggml_is_contiguous(state->aheads_cross_QKs));
+    ggml_tensor * w = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, n_tokens, n_audio_tokens, n_heads);
+    auto & data = state->aheads_cross_QKs_data;
+    data.resize(n_tokens * n_audio_ctx * n_heads);
+    ggml_backend_tensor_get(state->aheads_cross_QKs, data.data(), 0, sizeof(float) * n_tokens * n_audio_ctx * n_heads);
+    for (int k = 0; k < n_heads; ++k) {
+        for (int j = 0; j < n_audio_tokens; ++j) {
+            memcpy(
+                (char *) w->data + j * w->nb[1] + k * w->nb[2],
+                data.data() + j * n_tokens + k * n_tokens * n_audio_ctx,
+                n_tokens * sizeof(float)
+            );
+        }
+    }
+
+    // Normalize - in original OpenAI code, this is done over dim=-2. In this case,
+    // we already permuted N_TOKENS dimension to columns on last loop, becase ggml_norm
+    // operates over columns. Afterwards, permute to a shape that facilitates mean
+    // operation (after median filter)
+    // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
+    // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
+    w = ggml_norm(gctx, w, 1e-9);
+    w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
+
+    // Pass median filter - this is done over AUDIO_TOKENS dimension.
+    // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
+    // OUT: Same dims
+    median_filter_user_data mf_user_data = {medfilt_width};
+    w = ggml_map_custom1(gctx, w, median_filter, 1, &mf_user_data);
+
+    // Take mean over columns, scale by -1, reshape to 2D tensor, remove SOT sequence and EOT
+    // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
+    // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS dims
+    w = ggml_mean(gctx, w);
+    w = ggml_scale(gctx, w, -1.0);
+    w = ggml_reshape_2d(gctx, w, w->ne[1], w->ne[2]);
+
+    // Remove SOT sequence and EOT
+    // Out dimension is (N_TOKENS-sot_sequence_length-1)*N_AUDIO_TOKENS
+    w = ggml_view_2d(gctx, w, w->ne[0] - sot_sequence_length - 1, w->ne[1], w->nb[1], sot_sequence_length * w->nb[0]);
+
+    // Compute
+    struct ggml_cgraph * gf = ggml_new_graph(gctx);
+    ggml_build_forward_expand(gf, w);
+    ggml_graph_compute_with_ctx(gctx, gf, n_threads);
+
+    ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
+
+    // Place timestamps on segments
+    int32_t last_v = 0;
+    auto seg_i = state->result_all.begin() + i_segment;
+    auto tok_i = seg_i->tokens.begin();
+    for (int i = 0; i < alignment->ne[1]; ++i) {
+        int32_t v = ggml_get_i32_nd(alignment, 0, i, 0, 0);
+        if (v != last_v) {
+            int32_t time_index = ggml_get_i32_nd(alignment, 1, i, 0, 0);
+            int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio
+            last_v = v;
+
+            // Skip non-text tokens
+            while (!(tok_i->id < whisper_token_eot(ctx))) {
+                ++tok_i;
+                if (tok_i == seg_i->tokens.end()) {
+                    ++seg_i;
+                    tok_i = seg_i->tokens.begin();
+                }
+            }
+
+            tok_i->t_dtw = timestamp;
+            ++tok_i;
+            if (tok_i == seg_i->tokens.end()) {
+                ++seg_i;
+                tok_i = seg_i->tokens.begin();
+            }
+        }
+    }
+
+    // Print DTW timestamps
+    /*for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
+        auto & segment = state->result_all[i];
+        for (auto &t: segment.tokens) {
+            const char * tok = whisper_token_to_str(ctx, t.id);
+            fprintf(stderr, "|%s|(%.2f) ", tok, (float)t.t_dtw/100);
+        }
+        fprintf(stderr, "\n");
+    }*/
+
+    ggml_free(gctx);
+}
+
 void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
     g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
     g_state.log_callback_user_data = user_data;
index a5371eb3b9331a49f7479f8376df5a7d82c42c25..2754337f14307c577bcc43a3cd0d8b489083a37a 100644 (file)
@@ -84,9 +84,45 @@ extern "C" {
     typedef int32_t whisper_token;
     typedef int32_t whisper_seq_id;
 
+    enum whisper_alignment_heads_preset {
+        WHISPER_AHEADS_NONE,
+        WHISPER_AHEADS_N_TOP_MOST,  // All heads from the N-top-most text-layers
+        WHISPER_AHEADS_CUSTOM,
+        WHISPER_AHEADS_TINY_EN,
+        WHISPER_AHEADS_TINY,
+        WHISPER_AHEADS_BASE_EN,
+        WHISPER_AHEADS_BASE,
+        WHISPER_AHEADS_SMALL_EN,
+        WHISPER_AHEADS_SMALL,
+        WHISPER_AHEADS_MEDIUM_EN,
+        WHISPER_AHEADS_MEDIUM,
+        WHISPER_AHEADS_LARGE_V1,
+        WHISPER_AHEADS_LARGE_V2,
+        WHISPER_AHEADS_LARGE_V3,
+    };
+
+    typedef struct whisper_ahead {
+        int n_text_layer;
+        int n_head;
+    } whisper_ahead;
+
+    typedef struct whisper_aheads {
+        size_t n_heads;
+        const whisper_ahead * heads;
+    } whisper_aheads;
+
     struct whisper_context_params {
         bool  use_gpu;
         int   gpu_device;  // CUDA device
+
+        // [EXPERIMENTAL] Token-level timestamps with DTW
+        bool dtw_token_timestamps;
+        enum whisper_alignment_heads_preset dtw_aheads_preset;
+
+        int dtw_n_top;
+        struct whisper_aheads dtw_aheads;
+
+        size_t dtw_mem_size; // TODO: remove
     };
 
     typedef struct whisper_token_data {
@@ -103,6 +139,11 @@ extern "C" {
         int64_t t0;        // start time of the token
         int64_t t1;        //   end time of the token
 
+        // [EXPERIMENTAL] Token-level timestamps with DTW
+        // do not use if you haven't computed token-level timestamps with dtw
+        // Roughly corresponds to the moment in audio in which the token was output
+        int64_t t_dtw;
+
         float vlen;        // voice length of the token
     } whisper_token_data;