} \
} while (0)
-//#define WHISPER_USE_FLASH_ATTN
//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 8
#define WHISPER_MAX_NODES 4096
// ------
-#ifdef WHISPER_USE_FLASH_ATTN
- struct ggml_tensor * Q =
- ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Qcur,
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
- 0, 2, 1, 3);
-
- struct ggml_tensor * K =
- ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Kcur,
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
- 0, 2, 1, 3);
-
- struct ggml_tensor * V =
- ggml_cpy(ctx0,
- ggml_permute(ctx0,
- ggml_reshape_3d(ctx0,
- Vcur,
- n_state/n_head, n_head, n_ctx),
- 1, 2, 0, 3),
- ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head));
-
- struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
-#else
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_cpy(ctx0,
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
- struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);
-
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled);
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
struct ggml_tensor * V =
ggml_cpy(ctx0,
);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
-#endif
+
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
cur = ggml_cpy(ctx0,
ggml_set_name(KQ_mask, "KQ_mask");
ggml_set_input(KQ_mask);
+ struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16);
+
// token encoding + position encoding
struct ggml_tensor * cur =
ggml_add(ctx0,
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
- //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
-
- //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
- struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask);
-
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask_f16, 1.0f, 0.0f);
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,