]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : faster beam_search sampling via reduced KV cache copies (#1243)
authorbobqianic <redacted>
Sun, 10 Sep 2023 13:04:27 +0000 (21:04 +0800)
committerGitHub <redacted>
Sun, 10 Sep 2023 13:04:27 +0000 (16:04 +0300)
* Faster `beam_search` sampling

Refine the KV cache update logic for more intelligent and efficient updating.

* Faster `whisper_sample_token_topk`

* Update whisper.cpp

* Update whisper.cpp

* Update whisper.cpp

* Reduce `memory allocation`

* Add `pointer swapping`

* Fixed some bugs

* Update whisper.cpp

* Apply suggestions from code review

* Updated the logic for determining `two-copy`

* Updated the logic for determining `two-copy` v2

* whisper : add debug logs + coding style

---------

Co-authored-by: Georgi Gerganov <redacted>
whisper.cpp

index 3192fbc640425b7d2714cc587971cb45e1d7e34c..5c14b43efdcc9bc2796891cf22eb43604d60c648 100644 (file)
@@ -18,6 +18,7 @@
 #include <cstring>
 #include <fstream>
 #include <map>
+#include <set>
 #include <string>
 #include <thread>
 #include <vector>
@@ -537,6 +538,7 @@ struct whisper_kv_cache {
 
     struct ggml_context * ctx;
 
+    // buf points to the memory allocated for both ggml_tensor 'k' and 'v' (see kv_cache_init)
     std::vector<uint8_t> buf;
 
     int n; // number of tokens currently in the cache
@@ -602,7 +604,7 @@ struct whisper_sequence {
 
 // TAGS: WHISPER_DECODER_INIT
 struct whisper_decoder {
-    // each decoders keeps its own KV-cache
+    // each decoder keeps its own KV-cache
     whisper_kv_cache kv_self;
 
     // the currently generated sequence of tokens
@@ -622,6 +624,24 @@ struct whisper_decoder {
     std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
 };
 
+// replace std::pair by using customized pair struct (reason: std::pair is very slow)
+template<typename A, typename B>
+struct whisper_pair {
+    A first;
+    B second;
+
+    // Define a constructor that takes two arguments.
+    whisper_pair(const A& a, const B& b) : first(a), second(b) {}
+    // Define a constructor that takes no argument.
+    whisper_pair() : first(A()), second(B()) {}
+};
+
+// beam-search helpers
+struct kv_buf {
+    std::vector<uint8_t> k;
+    std::vector<uint8_t> v;
+};
+
 struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
@@ -641,6 +661,9 @@ struct whisper_state {
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
 
+    // buffer for swapping KV caches between decoders during beam-search
+    std::vector<kv_buf> kv_swap_bufs;
+
     // memory buffers used by encode / decode contexts
     std::vector<uint8_t> buf_compute;
     std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
@@ -655,7 +678,7 @@ struct whisper_state {
     std::vector<whisper_token>   prompt_past;
 
     // work container used to avoid memory allocations
-    std::vector<std::pair<double, whisper_vocab::id>> logits_id;
+    std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
 
     mutable std::mt19937 rng; // used for sampling at t > 0.0
 
@@ -3975,17 +3998,21 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
 
     auto & logits_id = state.logits_id;
 
-    logits_id.clear();
+    logits_id.resize(n_logits);
     for (int i = 0; i < n_logits; ++i) {
-        logits_id.push_back({ logits[i], i });
+        logits_id[i].first = logits[i];
+        logits_id[i].second = i;
     }
 
-    std::partial_sort(
-            logits_id.begin(),
-            logits_id.begin() + k, logits_id.end(),
-            [](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
-                return a.first > b.first;
-            });
+    {
+        using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
+        std::partial_sort(
+                logits_id.begin(),
+                logits_id.begin() + k, logits_id.end(),
+                [](const pair_type & a, const pair_type & b) {
+            return a.first > b.first;
+        });
+    }
 
     std::vector<whisper_token_data> result;
     result.reserve(k);
@@ -4080,6 +4107,115 @@ static void whisper_sequence_score(
     }
 }
 
+static bool whisper_kv_swap_fast(
+                   std::vector<int> & view,
+                    whisper_decoder   src[],
+                std::vector<kv_buf> & kv_swap_bufs,
+                          const int & n_decoders) {
+    WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
+
+    // (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
+    std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
+
+    // (buffer->decoder or decoder->decoder)
+    std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
+
+    // (decoder<->decoder)
+    std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
+    std::vector<whisper_pair<int, int>> p_swap_vec;
+    p_swap_vec.reserve(n_decoders);
+
+    // see https://github.com/ggerganov/whisper.cpp/wiki
+    for (int i = 0; i < n_decoders; i++) {
+        // zero-copy (no modification)
+        if (i == view[i] || view[i] < 0) {
+            continue;
+        }
+
+        bool is_one_copy = true;
+        // since we modify data sequentially, we only consider decoder indices after current index
+        for (int j = i + 1; j < n_decoders; j++) {
+            if (i == view[j]) {
+                // detect symmetric diagram
+                if (j == view[i]) {
+                    p_swap_set.insert(i);
+                    p_swap_set.insert(j);
+                    p_swap_vec.emplace_back(i, j);
+                } else {
+                    two_copy.insert(i);
+                    is_one_copy = false;
+                }
+                break;
+            }
+        }
+        if (is_one_copy) {
+            one_copy.insert(i);
+        }
+    }
+
+    kv_swap_bufs.resize(n_decoders);
+
+    for (int i = 0; i < n_decoders; i++) {
+        kv_swap_bufs[i].k.resize(ggml_nbytes(src[i].kv_self.k));
+        kv_swap_bufs[i].v.resize(ggml_nbytes(src[i].kv_self.v));
+    }
+
+    for (auto & i : two_copy) {
+        // make a copy of KV caches
+        WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
+        memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
+        memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
+    }
+
+    // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
+    for (auto & i : two_copy) {
+        // skip the decoder indices that require pointer swapping
+        if (p_swap_set.find(i) != p_swap_set.end()) {
+            continue;
+        }
+
+        if (two_copy.find(view[i]) != two_copy.end()) {
+            // modify KV caches of decoder using data from kv_swap_bufs
+            WHISPER_PRINT_DEBUG("%s: two-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
+            memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
+            memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+        } else {
+            // modify KV caches of decoder using data from correspond decoder KV caches directly
+            WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers:      %d  -> %d\n", __func__, view[i], i);
+            memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
+            memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+        }
+    }
+
+    // then modify one-copy decoder KV caches
+    for (auto & i : one_copy) {
+        // skip the decoder indices that require pointer swapping
+        if (p_swap_set.find(i) != p_swap_set.end()) {
+            continue;
+        }
+
+        if (two_copy.find(view[i]) != two_copy.end()) {
+            // modify KV caches of decoder using data from kv_swap_bufs
+            WHISPER_PRINT_DEBUG("%s: one-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
+            memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
+            memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+        } else {
+            // modify KV caches of decoder using data from correspond decoder KV caches directly
+            WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers:      %d  -> %d\n", __func__, view[i], i);
+            memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
+            memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+        }
+    }
+
+    // swap the pointers
+    for (auto & i : p_swap_vec) {
+        WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
+        std::swap(src[i.first].kv_self, src[i.second].kv_self);
+    }
+
+    return true;
+}
+
 int whisper_full_with_state(
         struct whisper_context * ctx,
           struct whisper_state * state,
@@ -4243,14 +4379,6 @@ int whisper_full_with_state(
     std::vector<whisper_token> prompt;
     prompt.reserve(whisper_n_text_ctx(ctx));
 
-    // beam-search helpers
-    struct kv_buf {
-        std::vector<uint8_t> k;
-        std::vector<uint8_t> v;
-    };
-
-    std::vector<kv_buf> kv_bufs;
-
     struct beam_candidate {
         int decoder_idx;
         int seek_delta;
@@ -4399,23 +4527,7 @@ int whisper_full_with_state(
             for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
                 const int64_t t_start_sample_us = ggml_time_us();
 
-                // store the KV caches of all decoders when doing beam-search
                 if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
-                    kv_bufs.resize(n_decoders_cur);
-                    for (int j = 0; j < n_decoders_cur; ++j) {
-                        auto & decoder = state->decoders[j];
-
-                        if (decoder.completed || decoder.failed) {
-                            continue;
-                        }
-
-                        kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k));
-                        kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v));
-
-                        memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
-                        memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
-                    }
-
                     beam_candidates.clear();
                 }
 
@@ -4463,6 +4575,7 @@ int whisper_full_with_state(
                     });
 
                     uint32_t cur_c = 0;
+                    std::vector<int> decoder_idx(n_decoders_cur, -1);
 
                     for (int j = 0; j < n_decoders_cur; ++j) {
                         auto & decoder = state->decoders[j];
@@ -4481,12 +4594,13 @@ int whisper_full_with_state(
                         decoder.seek_delta = cur.seek_delta;
                         decoder.has_ts     = cur.has_ts;
 
-                        memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size());
-                        memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
-
+                        decoder_idx[j] = cur.decoder_idx;
                         WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
                                 __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
                     }
+
+                    // update KV caches
+                    whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
                 }
 
                 // update the decoder state