} \
} while (0)
-//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 8
#define WHISPER_MAX_NODES 4096
int32_t n_fail_p = 0; // number of logprob threshold failures
int32_t n_fail_h = 0; // number of entropy threshold failures
+ // number of decoders for which we have constructed the KV cache
+ int32_t kv_self_n_dec = 0;
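+ // when a later decoding run requests more decoders than this, the unified cache below is recreated with a larger size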
+
// unified self-attention KV cache for all decoders
whisper_kv_cache kv_self;
struct ggml_tensor * Q =
ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Qcur,
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)),
+ ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
0, 2, 1, 3);
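+ // note: assuming Qcur is [n_state, n_ctx] with n_state = n_head*n_state_head, the reshape is a zero-copy view
+ // [n_state_head, n_head, n_ctx] and the permute yields [n_state_head, n_ctx, n_head] (one Q matrix per head),
+ // avoiding the extra F32 copy that the previous ggml_cpy performed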
if (wctx.params.flash_attn) {
} else {
struct ggml_tensor * K =
ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Kcur,
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)),
+ ggml_cast(ctx0,
+ ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
+ wctx.itype),
0, 2, 1, 3);
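+ // ggml_cast converts the reshaped K to wctx.itype (typically F16) in one op instead of copying into a
+ // separately allocated tensor; the permuted result has the same [n_state_head, n_ctx, n_head] layout as Q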
// K * Q
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
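+ // ggml_soft_max_ext fuses the KQ scaling into the softmax (KQscale is presumably 1/sqrt(n_state_head));
+ // nullptr mask and 0.0f max_bias mean no masking and no ALiBi bias here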
struct ggml_tensor * V =
- ggml_cpy(ctx0,
+ ggml_cast(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
Vcur,
n_state_head, n_head, n_ctx),
1, 2, 0, 3),
- ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)
- );
+ wctx.itype);
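+ // the cast both converts V to wctx.itype and materializes the transposed [n_ctx, n_state_head, n_head]
+ // layout contiguously, replacing the old ggml_cpy into a pre-allocated tensor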
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
- cur = ggml_cpy(ctx0,
- KQV_merged,
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+ cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
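+ // ggml_cont_2d makes KQV_merged contiguous and views it as [n_state, n_ctx] in a single call,
+ // replacing the previous ggml_cpy into a new F32 tensor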
}
}
layer.mlp_ln_b);
}
-#ifdef WHISPER_USE_FLASH_FF
- cur = ggml_flash_ff(ctx0,
- ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
- layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
-#else
// fully connected
cur = ggml_mul_mat(ctx0,
layer.mlp_0_w,
cur);
cur = ggml_add(ctx0, cur, layer.mlp_1_b);
-#endif
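+ // with the experimental WHISPER_USE_FLASH_FF path removed, the feed-forward block always uses the plain mat-mul + bias path above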
}
inpL = ggml_add(ctx0, cur, inpFF);
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
- cur = ggml_cpy(ctx0,
- KQV_merged,
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+ cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
}
}
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
- cur = ggml_cpy(ctx0,
- KQV_merged,
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+ cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
}
}
whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel);
}
- // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
- // in theory, there can be a case where this is not enough, but in practice it should always be enough
- const int factor = 3;
-
+ // at this point, we don't yet know how many decoders will be used
+ // if more decoders are needed later during decoding, the KV cache will be recreated accordingly
+ state->kv_self_n_dec = 1;
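+ // a single decoder never needs more than n_text_ctx cells; GGML_PAD rounds this up to a multiple of 256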
if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
ctx->model.hparams.n_text_state,
ctx->model.hparams.n_text_layer,
- GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+ GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
whisper_free_state(state);
return nullptr;
}
WHISPER_LOG_DEBUG("\n\n");
+ // recreate the KV cache if the number of decoders has increased
+ if (state->kv_self_n_dec < n_decoders_cur) {
+ WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
+
+ whisper_kv_cache_free(state->kv_self);
+
+ // overallocate to work around KV cache fragmentation issues
+ const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
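+ // e.g. with beam size 5, n_decoders_cur = 5 and factor = 7, so the cache holds 7x the padded text context
+ // (for the released Whisper models n_text_ctx = 448, padded to 512, i.e. 7*512 = 3584 cells)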
+
+ if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+ ctx->model.hparams.n_text_state,
+ ctx->model.hparams.n_text_layer,
+ GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+ whisper_free_state(state);
+ return -7;
+ }
+
+ state->kv_self_n_dec = n_decoders_cur;
+ }
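+ // note: kv_self_n_dec only grows within a state, so the cache is not shrunk back if a later call uses fewer decoders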
+
whisper_kv_cache_clear(state->kv_self);
whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
- return -7;
+ return -8;
}
{
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
- return -8;
+ return -9;
}
const int64_t t_start_sample_us = ggml_time_us();