]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : fix excessive memory usage (#2443)
authorGeorgi Gerganov <redacted>
Sat, 5 Oct 2024 09:36:40 +0000 (12:36 +0300)
committerGitHub <redacted>
Sat, 5 Oct 2024 09:36:40 +0000 (12:36 +0300)
* whisper : fix KV cache allocation

* whisper : reduce memory overhead from unused input tensors

src/whisper.cpp

index 9c7c66b01a3e05320bf2c676df123672d0323e19..101c4f5b170f48691c2e2b34f9087dfa38b145de 100644 (file)
@@ -163,7 +163,6 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
         } \
     } while (0)
 
-//#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 8
 #define WHISPER_MAX_NODES 4096
 
@@ -817,6 +816,9 @@ struct whisper_state {
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
+    // number of decoders for which we have constructed the KV cache
+    int32_t kv_self_n_dec = 0;
+
     // unified self-attention KV cache for all decoders
     whisper_kv_cache kv_self;
 
@@ -2096,9 +2098,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)),
+                        ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
                         0, 2, 1, 3);
 
             if (wctx.params.flash_attn) {
@@ -2125,9 +2125,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
             } else {
                 struct ggml_tensor * K =
                     ggml_permute(ctx0,
-                            ggml_cpy(ctx0,
-                                Kcur,
-                                ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)),
+                            ggml_cast(ctx0,
+                                ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
+                                wctx.itype),
                             0, 2, 1, 3);
 
                 // K * Q
@@ -2136,22 +2136,19 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                 struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
 
                 struct ggml_tensor * V =
-                    ggml_cpy(ctx0,
+                    ggml_cast(ctx0,
                             ggml_permute(ctx0,
                                 ggml_reshape_3d(ctx0,
                                     Vcur,
                                     n_state_head, n_head, n_ctx),
                                 1, 2, 0, 3),
-                            ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)
-                            );
+                            wctx.itype);
 
                 struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
                 struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-                cur = ggml_cpy(ctx0,
-                        KQV_merged,
-                        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+                cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
             }
         }
 
@@ -2181,11 +2178,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                         layer.mlp_ln_b);
             }
 
-#ifdef WHISPER_USE_FLASH_FF
-            cur = ggml_flash_ff(ctx0,
-                    ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
-                    layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
-#else
             // fully connected
             cur = ggml_mul_mat(ctx0,
                     layer.mlp_0_w,
@@ -2202,7 +2194,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                     cur);
 
             cur = ggml_add(ctx0, cur, layer.mlp_1_b);
-#endif
         }
 
         inpL = ggml_add(ctx0, cur, inpFF);
@@ -2578,9 +2569,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
                 struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-                cur = ggml_cpy(ctx0,
-                        KQV_merged,
-                        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+                cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
             }
         }
 
@@ -2687,9 +2676,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
                 struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-                cur = ggml_cpy(ctx0,
-                        KQV_merged,
-                        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+                cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
             }
         }
 
@@ -3403,14 +3390,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel);
     }
 
-    // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
-    // in theory, there can be a case where this is not enough, but in practice it should always be enough
-    const int factor = 3;
-
+    // at this point, we don't know yet how many decoders will be used
+    // later during decoding, if more decoders are used, we will recreate the KV cache respectively
+    state->kv_self_n_dec = 1;
     if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
                 ctx->model.hparams.n_text_state,
                 ctx->model.hparams.n_text_layer,
-                GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+                GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
         WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
@@ -5775,13 +5761,34 @@ int whisper_full_with_state(
                 }
                 WHISPER_LOG_DEBUG("\n\n");
 
+                // recreate the KV cache if the number of decoders has changed
+                if (state->kv_self_n_dec < n_decoders_cur) {
+                    WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
+
+                    whisper_kv_cache_free(state->kv_self);
+
+                    // overallocate to workaround KV cache fragmentation issues
+                    const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
+
+                    if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+                                ctx->model.hparams.n_text_state,
+                                ctx->model.hparams.n_text_layer,
+                                GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+                        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+                        whisper_free_state(state);
+                        return -7;
+                    }
+
+                    state->kv_self_n_dec = n_decoders_cur;
+                }
+
                 whisper_kv_cache_clear(state->kv_self);
 
                 whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
 
                 if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                     WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-                    return -7;
+                    return -8;
                 }
 
                 {
@@ -6081,7 +6088,7 @@ int whisper_full_with_state(
 
                     if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                         WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-                        return -8;
+                        return -9;
                     }
 
                     const int64_t t_start_sample_us = ggml_time_us();