From: Georgi Gerganov
Date: Wed, 15 May 2024 06:38:19 +0000 (+0300)
Subject: whisper : use flash attention (whisper/2152)
X-Git-Tag: upstream/0.0.1642~695
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=da79cd67548704dd5bead501b8ec2b9b265d0bbb;p=pkg%2Fggml%2Fsources%2Fggml

whisper : use flash attention (whisper/2152)

* whisper : use flash attention in the encoder
* whisper : add kv_pad
* whisper : remove extra backend instance (huh?)
* whisper : use FA for cross-attention
* whisper : use FA for self-attention
* whisper : simplify encoder FA
* whisper : add flash_attn runtime parameter
* scripts : add bench log
* scripts : add M1 Pro bench log
---

diff --git a/examples/whisper/main.cpp b/examples/whisper/main.cpp
index d11c1c3f..45eb17fe 100644
--- a/examples/whisper/main.cpp
+++ b/examples/whisper/main.cpp
@@ -70,6 +70,7 @@ struct whisper_params {
     bool no_timestamps = false;
     bool log_score = false;
     bool use_gpu = true;
+    bool flash_attn = false;

     std::string language = "en";
     std::string prompt;
@@ -168,7 +169,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
         else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
         else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
-        else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
+        else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
+        else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
         else if ( arg == "--grammar") { params.grammar = argv[++i]; }
         else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
         else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
@@ -234,6 +236,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
     fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
     fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
+    fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ?
"true" : "false"); fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str()); fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str()); @@ -977,7 +980,9 @@ int main(int argc, char ** argv) { // whisper init struct whisper_context_params cparams = whisper_context_default_params(); - cparams.use_gpu = params.use_gpu; + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; if (!params.dtw.empty()) { cparams.dtw_token_timestamps = true; diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index ff4223da..84aec823 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -809,14 +809,15 @@ struct whisper_state { // shared between all decoders whisper_kv_cache kv_cross; + // padded buffer for flash-attention + whisper_kv_cache kv_pad; + whisper_mel mel; whisper_batch batch; whisper_decoder decoders[WHISPER_MAX_DECODERS]; - ggml_backend_t backend = nullptr; - // ggml-alloc: // - stores meta info about the intermediate tensors into the `meta` buffers // - stores the actual tensor data into the `data` buffers @@ -902,14 +903,12 @@ static void read_safe(whisper_model_loader * loader, T & dest) { } static bool kv_cache_init( - const struct whisper_hparams & hparams, struct whisper_kv_cache & cache, ggml_backend_t backend, ggml_type wtype, + int64_t n_text_state, + int64_t n_text_layer, int n_ctx) { - const int64_t n_text_state = hparams.n_text_state; - const int64_t n_text_layer = hparams.n_text_layer; - const int64_t n_mem = n_text_layer*n_ctx; const int64_t n_elements = n_text_state*n_mem; @@ -941,6 +940,8 @@ static bool kv_cache_init( return false; } + ggml_backend_buffer_clear(cache.buffer, 0); + return true; } @@ -1068,6 +1069,26 @@ static void whisper_kv_cache_seq_cp( } } +static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) { + if (!wctx.params.flash_attn) { + return 1u; + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(wctx.backend)) { + return 32u; + } +#endif + +#ifdef GGML_USE_CUDA + if (ggml_backend_is_cuda(wctx.backend)) { + return 256u; + } +#endif + + return 1u; +} + // [EXPERIMENTAL] Token-level timestamps with DTW static bool aheads_masks_init( const whisper_context_params & cparams, @@ -1872,6 +1893,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder( const int n_head = hparams.n_audio_head; const int n_layer = hparams.n_audio_layer; + const int n_state_head = n_state/n_head; + + auto & kv_pad = wstate.kv_pad; + + WHISPER_ASSERT(!!kv_pad.ctx); + + const int n_ctx_pad = GGML_PAD(n_ctx, 256); + struct ggml_init_params params = { /*.mem_size =*/ wstate.alloc_encode.meta.size(), /*.mem_buffer =*/ wstate.alloc_encode.meta.data(), @@ -1884,7 +1913,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder( struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv); - const float KQscale = 1.0f/sqrtf(float(n_state)/n_head); + const float KQscale = 1.0f/sqrtf(float(n_state_head)); // =================================================================== // NOTE: experimenting with partial evaluation of the encoder (ignore) @@ -1934,14 +1963,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder( Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b); - //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state)/n_head, -0.25)); + //Qcur = ggml_scale(ctx0, Qcur, 
pow(float(n_state_head), -0.25)); // note: no bias for Key struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.attn_k_w, cur); - //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state)/n_head, -0.25)); + //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25)); struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.attn_v_w, @@ -1955,38 +1984,61 @@ static struct ggml_cgraph * whisper_build_graph_encoder( ggml_permute(ctx0, ggml_cpy(ctx0, Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)), 0, 2, 1, 3); - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Kcur, - ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + if (wctx.params.flash_attn) { + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0))); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_pad.k, + n_state_head, n_ctx_pad, n_head, + ggml_element_size(kv_pad.k)*n_state, + ggml_element_size(kv_pad.k)*n_state_head, + 0); - struct ggml_tensor * V = - ggml_cpy(ctx0, - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - Vcur, - n_state/n_head, n_head, n_ctx), - 1, 2, 0, 3), - ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head) - ); + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_pad.v, + n_state_head, n_ctx_pad, n_head, + ggml_element_size(kv_pad.v)*n_state, + ggml_element_size(kv_pad.v)*n_state_head, + 0); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f); - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); + cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx); + } else { + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Kcur, + ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + + struct ggml_tensor * V = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + Vcur, + n_state_head, n_head, n_ctx), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head) + ); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); + } } // projection @@ -2085,6 +2137,10 @@ static struct ggml_cgraph * whisper_build_graph_cross( const int n_state = hparams.n_audio_state; const int n_head = hparams.n_audio_head; + const int n_state_head = n_state/n_head; + + const int n_ctx_pad = GGML_PAD(n_ctx, 256); + struct ggml_init_params params = { /*.mem_size =*/ wstate.alloc_cross.meta.size(), /*.mem_buffer =*/ wstate.alloc_cross.meta.data(), @@ -2097,18 +2153,18 @@ static struct ggml_cgraph * whisper_build_graph_cross( struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc); - const float Kscale = pow(float(n_state) / n_head, -0.25); + const float Kscale = 
pow(float(n_state_head), -0.25); for (int il = 0; il < model.hparams.n_text_layer; ++il) { auto & layer = model.layers_decoder[il]; - struct ggml_tensor* Kcross = ggml_mul_mat(ctx0, + struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, layer.cross_attn_k_w, cur); Kcross = ggml_scale(ctx0, Kcross, Kscale); - struct ggml_tensor* Vcross = ggml_mul_mat(ctx0, + struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, layer.cross_attn_v_w, cur); @@ -2116,15 +2172,25 @@ static struct ggml_cgraph * whisper_build_graph_cross( Vcross, layer.cross_attn_v_b); - Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); + struct ggml_tensor * k; + struct ggml_tensor * v; - struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, - n_state*n_ctx, - (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + if (wctx.params.flash_attn) { + k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, + (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad)); - struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, - ( n_ctx)*ggml_element_size(wstate.kv_cross.v), - (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state); + v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, + (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad)); + } else { + Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); + + k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, + (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + + v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, + ( n_ctx)*ggml_element_size(wstate.kv_cross.v), + (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state); + } ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v)); @@ -2195,7 +2261,7 @@ static bool whisper_encode_internal( } if (!whisper_encode_external(wstate)) { - if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { return false; } } else { @@ -2218,7 +2284,7 @@ static bool whisper_encode_internal( return false; } - if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { return false; } } @@ -2234,7 +2300,7 @@ static bool whisper_encode_internal( return false; } - if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { return false; } } @@ -2263,11 +2329,15 @@ static struct ggml_cgraph * whisper_build_graph_decoder( const int n_head = hparams.n_text_head; const int n_layer = hparams.n_text_layer; + const int n_state_head = n_state/n_head; + const int n_tokens = batch.n_tokens; const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; - const int32_t n_kv = worst_case ? n_ctx : kv_self.n; - const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head; + const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256); + + const int32_t n_kv = worst_case ? n_ctx : kv_self.n; + const int32_t kv_head = worst_case ? 
n_ctx - n_tokens : kv_self.head; //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx); @@ -2289,12 +2359,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder( ggml_set_name(position, "position"); ggml_set_input(position); - const float KQscale = pow(float(n_state)/n_head, -0.25); + const float KQscale = pow(float(n_state_head), -0.25); - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1); ggml_set_name(KQ_mask, "KQ_mask"); ggml_set_input(KQ_mask); + struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16); + // token encoding + position encoding struct ggml_tensor * cur = ggml_add(ctx0, @@ -2350,12 +2422,25 @@ static struct ggml_cgraph * whisper_build_graph_decoder( Vcur, layer.attn_v_b); - Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens)); + struct ggml_tensor * k; + struct ggml_tensor * v; + + if (wctx.params.flash_attn) { + k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, + (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head)); + + v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state, + (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head)); + } else { + Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens)); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head)); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v)); + k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, + (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head)); + + v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v)); + } ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); @@ -2365,35 +2450,48 @@ static struct ggml_cgraph * whisper_build_graph_decoder( struct ggml_tensor * Q = ggml_permute(ctx0, - ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens), + ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens), 0, 2, 1, 3); struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, - n_state/n_head, n_kv, n_head, + n_state_head, n_kv, n_head, ggml_element_size(kv_self.k)*n_state, - ggml_element_size(kv_self.k)*n_state/n_head, + ggml_element_size(kv_self.k)*n_state_head, ggml_element_size(kv_self.k)*n_state*n_ctx*il); - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + if (wctx.params.flash_attn) { + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_state_head, n_kv, n_head, + ggml_element_size(kv_self.v)*n_state, + ggml_element_size(kv_self.v)*n_state_head, + ggml_element_size(kv_self.v)*n_state*n_ctx*il); + + cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f); + + cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens); + } else { + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f); + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f); - struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v, - n_kv, 
n_state/n_head, n_head, - n_ctx*ggml_element_size(kv_self.v), - n_ctx*ggml_element_size(kv_self.v)*n_state/n_head, - n_ctx*ggml_element_size(kv_self.v)*n_state*il); + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_state_head, n_head, + n_ctx*ggml_element_size(kv_self.v), + n_ctx*ggml_element_size(kv_self.v)*n_state_head, + n_ctx*ggml_element_size(kv_self.v)*n_state*il); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + } } // projection @@ -2432,80 +2530,77 @@ static struct ggml_cgraph * whisper_build_graph_decoder( Qcur, layer.cross_attn_q_b); - Qcur = ggml_scale(ctx0, Qcur, KQscale); - - // Kcross is already scaled - struct ggml_tensor * Kcross = - ggml_view_3d(ctx0, wstate.kv_cross.k, - n_state/n_head, n_audio_ctx, n_head, - ggml_element_size(wstate.kv_cross.k)*n_state, - ggml_element_size(wstate.kv_cross.k)*n_state/n_head, - ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il); - - //struct ggml_tensor * Vcross = - // ggml_reshape_3d(ctx0, - // ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state), - // n_state/n_head, n_head, n_audio_ctx); - - //struct ggml_tensor * V_trans = - // ggml_cpy(ctx0, - // ggml_permute(ctx0, Vcross, 1, 2, 0, 3), - // ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head)); - - struct ggml_tensor * V = - ggml_view_3d(ctx0, wstate.kv_cross.v, - n_audio_ctx, n_state/n_head, n_head, - n_audio_ctx*ggml_element_size(wstate.kv_cross.v), - n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head, - n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il); - - // ------ - struct ggml_tensor * Q = ggml_permute(ctx0, - ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens), + ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens), 0, 2, 1, 3); - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q); - - //struct ggml_tensor * KQ_scaled = - // ggml_scale(ctx0, - // KQ, - // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) - // ); + if (wctx.params.flash_attn) { + struct ggml_tensor * Kcross = + ggml_view_3d(ctx0, wstate.kv_cross.k, + n_state_head, n_audio_ctx_pad, n_head, + ggml_element_size(wstate.kv_cross.k)*n_state, + ggml_element_size(wstate.kv_cross.k)*n_state_head, + ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il); - // no masking for cross-attention - //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); + struct ggml_tensor * Vcross = + ggml_view_3d(ctx0, wstate.kv_cross.v, + n_state_head, n_audio_ctx_pad, n_head, + ggml_element_size(wstate.kv_cross.v)*n_state, + ggml_element_size(wstate.kv_cross.v)*n_state_head, + ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il); - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); + cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f); - // [EXPERIMENTAL] Token-level timestamps with DTW - if (wctx.params.dtw_token_timestamps) { - if (wstate.aheads_masks.m[il] != nullptr) { - struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * 
KQ_soft_max->ne[1], KQ_soft_max->ne[2]); - aheads_KQs = ggml_transpose(ctx0, aheads_KQs); - aheads_KQs = ggml_cont(ctx0, aheads_KQs); - aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs); - aheads_KQs = ggml_transpose(ctx0, aheads_KQs); - aheads_KQs = ggml_cont(ctx0, aheads_KQs); - aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]); - if (aheads_cross_QKs == NULL) { - aheads_cross_QKs = aheads_KQs; - } else { - aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs); + cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens); + } else { + struct ggml_tensor * Kcross = + ggml_view_3d(ctx0, wstate.kv_cross.k, + n_state_head, n_audio_ctx, n_head, + ggml_element_size(wstate.kv_cross.k)*n_state, + ggml_element_size(wstate.kv_cross.k)*n_state_head, + ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il); + + struct ggml_tensor * Vcross = + ggml_view_3d(ctx0, wstate.kv_cross.v, + n_audio_ctx, n_state_head, n_head, + n_audio_ctx*ggml_element_size(wstate.kv_cross.v), + n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head, + n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il); + + // ------ + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + + // [EXPERIMENTAL] Token-level timestamps with DTW + if (wctx.params.dtw_token_timestamps) { + if (wstate.aheads_masks.m[il] != nullptr) { + struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]); + aheads_KQs = ggml_transpose(ctx0, aheads_KQs); + aheads_KQs = ggml_cont(ctx0, aheads_KQs); + aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs); + aheads_KQs = ggml_transpose(ctx0, aheads_KQs); + aheads_KQs = ggml_cont(ctx0, aheads_KQs); + aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]); + if (aheads_cross_QKs == NULL) { + aheads_cross_QKs = aheads_KQs; + } else { + aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs); + } } } - } - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max); - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - // cur = KQV_merged.contiguous().view(n_state, n_tokens) - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + } } // projection @@ -2638,7 +2733,9 @@ static bool whisper_decode_internal( return false; } - kv_self.n = whisper_kv_cache_cell_max(kv_self); + const uint32_t pad = whisper_kv_cache_get_padding(wctx); + kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad))); + //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self))); //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]); } @@ -2672,9 +2769,10 @@ static bool whisper_decode_internal( struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask"); auto & kv_self = wstate.kv_self; - const int32_t n_kv = kv_self.n; - wstate.inp_mask.resize(n_kv*n_tokens); + const int32_t 
n_kv = kv_self.n; + + wstate.inp_mask.resize(ggml_nelements(KQ_mask)); float * data = wstate.inp_mask.data(); memset(data, 0, ggml_nbytes(KQ_mask)); @@ -2690,6 +2788,12 @@ static bool whisper_decode_internal( } } } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } } ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float)); @@ -2697,7 +2801,7 @@ static bool whisper_decode_internal( logits = gf->nodes[gf->n_nodes - 1]; - if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { return false; } } @@ -3144,18 +3248,14 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { whisper_state * state = new whisper_state; - state->backend = whisper_backend_init(ctx->params); - if (!state->backend) { - WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__); - whisper_free_state(state); - return nullptr; - } - // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx // in theory, there can be a case where this is not enough, but in practice it should always be enough const int factor = 3; - if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) { + if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype, + ctx->model.hparams.n_text_state, + ctx->model.hparams.n_text_layer, + GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) { WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); whisper_free_state(state); return nullptr; @@ -3166,7 +3266,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6); } - if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) { + if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype, + ctx->model.hparams.n_text_state, + ctx->model.hparams.n_text_layer, + GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__); whisper_free_state(state); return nullptr; @@ -3177,6 +3280,20 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6); } + if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype, + ctx->model.hparams.n_audio_state, + 1, + GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { + WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); + whisper_free_state(state); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v); + WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6); + } + // [EXPERIMENTAL] Token-level timestamps with DTW if (ctx->params.dtw_token_timestamps) { if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) { @@ -3347,6 +3464,7 @@ int whisper_ctx_init_openvino_encoder( struct whisper_context_params whisper_context_default_params() { struct whisper_context_params result = { /*.use_gpu =*/ true, + /*.flash_attn =*/ false, /*.gpu_device =*/ 0, /*.dtw_token_timestamps =*/ false, @@ -3445,6 +3563,16 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu 
 struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
     ggml_time_init();

+    if (params.flash_attn && params.dtw_token_timestamps) {
+        WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
+        params.dtw_token_timestamps = false;
+    }
+
+    WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
+    WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
+    WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
+    WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
+
     whisper_context * ctx = new whisper_context;
     ctx->params = params;

@@ -3533,6 +3661,7 @@ void whisper_free_state(struct whisper_state * state) {
     if (state) {
         kv_cache_free(state->kv_self);
         kv_cache_free(state->kv_cross);
+        kv_cache_free(state->kv_pad);

 #ifdef WHISPER_USE_COREML
         if (state->ctx_coreml != nullptr) {
@@ -3555,8 +3684,6 @@ void whisper_free_state(struct whisper_state * state) {
         ggml_gallocr_free(state->alloc_cross.alloc);
         ggml_gallocr_free(state->alloc_decode.alloc);

-        ggml_backend_free(state->backend);
-
         // [EXPERIMENTAL] Token-level timestamps with DTW
         aheads_masks_free(state->aheads_masks);

diff --git a/examples/whisper/whisper.h b/examples/whisper/whisper.h
index 6a875d3b..9c7c58d8 100644
--- a/examples/whisper/whisper.h
+++ b/examples/whisper/whisper.h
@@ -113,6 +113,7 @@ extern "C" {

     struct whisper_context_params {
         bool use_gpu;
+        bool flash_attn;
         int gpu_device;  // CUDA device

         // [EXPERIMENTAL] Token-level timestamps with DTW
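
The new flash_attn flag is plumbed from the example CLI (-fa / --flash-attn) through whisper_params into whisper_context_params, so applications that embed whisper.cpp can opt in the same way the example does. A minimal sketch, assuming a placeholder model path and omitting the actual transcription call and error handling details (neither is part of this patch):

    #include "whisper.h"

    int main() {
        struct whisper_context_params cparams = whisper_context_default_params();

        cparams.use_gpu    = true;   // default
        cparams.flash_attn = true;   // new field introduced by this patch

        // "ggml-base.en.bin" is an illustrative model path
        struct whisper_context * ctx = whisper_init_from_file_with_params("ggml-base.en.bin", cparams);
        if (ctx == NULL) {
            return 1;   // failed to load the model
        }

        // ... run whisper_full() as usual ...

        whisper_free(ctx);
        return 0;
    }

From the example CLI the equivalent is something like: ./main -m ggml-base.en.bin -f samples/jfk.wav -fa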
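
With flash attention enabled, the K/V data has to be laid out with backend-dependent padding: whisper_kv_cache_get_padding() returns 32 on Metal, 256 on CUDA and 1 otherwise (and always 1 when flash_attn is off), and the caches are sized with GGML_PAD(..., 256) at init time. GGML_PAD rounds its argument up to the next multiple of the padding, so the cache view the decoder works on never has a ragged tail. A small sketch of the arithmetic, where pad_up is an illustrative stand-in for the ggml macro:

    // pad_up(x, n): smallest multiple of n that is >= x
    static uint32_t pad_up(uint32_t x, uint32_t n) {
        return ((x + n - 1) / n) * n;
    }

    // pad_up(  1, 256) == 256
    // pad_up(300, 256) == 512
    // pad_up(300,   1) == 300   // padding of 1 == no padding (flash_attn disabled)

    // how whisper_decode_internal sizes the active part of the self-attention cache (from this patch):
    //   const uint32_t pad = whisper_kv_cache_get_padding(wctx);
    //   kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));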
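
In both the encoder and the decoder the patch keeps the original attention path and adds a branch guarded by wctx.params.flash_attn. Conceptually the two branches compute the same thing, out = V * softmax(scale * K^T Q + mask); the flash-attention branch fuses the three steps into a single ggml op and reads K/V from padded cache views (in the decoder, Q and K are pre-scaled, hence the 1.0f scale). Side by side, with tensor names as in the patch:

    // original path:
    //   struct ggml_tensor * KQ          = ggml_mul_mat(ctx0, K, Q);                          // attention scores
    //   struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);  // masked softmax
    //   struct ggml_tensor * KQV         = ggml_mul_mat(ctx0, V, KQ_soft_max);                // weighted sum of V
    //   cur = ggml_cpy(ctx0, ggml_permute(ctx0, KQV, 0, 2, 1, 3), ...);                       // back to [n_state, n_tokens]

    // flash-attention path:
    //   cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f);
    //   cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);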
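
One practical detail: the fused op consumes the mask as the F16 copy KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16), and KQ_mask itself is now allocated with GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) rows instead of n_tokens. whisper_decode_internal therefore resizes inp_mask to ggml_nelements(KQ_mask) and fully masks the extra rows, so the padded query positions contribute nothing. The fill loop from the patch, for reference:

    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
        for (int j = 0; j < n_kv; ++j) {
            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
        }
    }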
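
Unlike the decoder, the encoder has no persistent KV cache, but the fused op still needs K/V padded along the context dimension. That is what the new kv_pad buffer is for: a single-layer scratch cache of n_audio_state width, sized GGML_PAD(n_audio_ctx, 256), which every encoder layer copies its current Kcur/Vcur into and then reads back through padded views. A rough sketch of the per-layer flow with flash attention on, names as in the patch:

    // copy this layer's K/V into the shared padded scratch buffer
    //   ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
    //   ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));
    //
    // view them as [n_state_head, n_ctx_pad, n_head] (n_ctx_pad = GGML_PAD(n_ctx, 256)) and run the fused attention
    //   K   = ggml_view_3d(ctx0, kv_pad.k, n_state_head, n_ctx_pad, n_head, ...);
    //   V   = ggml_view_3d(ctx0, kv_pad.v, n_state_head, n_ctx_pad, n_head, ...);
    //   cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f);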