std::vector<uint8_t> buf_compute;
std::vector<uint8_t> buf_compute_layer;
+ ggml_type wtype; // weight type (FP32 or FP16)
+
whisper_model model;
whisper_vocab vocab;
// for the big tensors, we have the option to store the data in 16-bit floats
// in order to save memory and also to speed up the computation
- const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+ wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
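+ // local alias: the ctx_size estimate below must use the same type the tensors will be created with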
+ const ggml_type wtype = wctx.wtype;
size_t ctx_size = 0;
// encoder
{
- // TODO: F16 .. maybe not?
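+ // the positional embedding stays in F32 regardless of wtype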
ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
// decoder
{
- // TODO: F16 .. maybe not?
ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
const int n_mem = n_text_layer*n_text_ctx;
const int n_elements = n_text_state*n_mem;
- model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
- model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
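+ // the decoder self-attention KV cache follows the weight type instead of hard-coded F16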
+ model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
+ model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
}
// key/value memory for the cross-attention layer
const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem;
- model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
- model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
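+ // same for the cross-attention KV cache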
+ model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
+ model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
}
const size_t memory_size =
ggml_nbytes(model.memory_k)       + ggml_nbytes(model.memory_v) +
ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
struct ggml_tensor * Q =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
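+ // scratch copies of Q/K/V for ggml_flash_attn are created with the configured type too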
+ ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Kcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+ ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * V =
ggml_cpy(ctxL,
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
Vcur,
n_state/n_head, n_head, n_ctx),
1, 2, 0, 3),
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
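+ // the (1, 2, 0, 3) permute above puts V in the transposed layout that ggml_flash_attn expects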
+ ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
struct ggml_tensor * K =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Kcur,
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
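+ // same change for the K scratch copy in the non-flash path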
+ ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
// K * Q
// struct ggml_tensor * V_trans =
//     ggml_permute(ctxL,
// ggml_cpy(ctxL,
// Vcur,
- // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
+ // ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
// 1, 2, 0, 3);
//struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
struct ggml_tensor * V =
ggml_cpy(ctxL,
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
Vcur,
n_state/n_head, n_head, n_ctx),
0, 2, 1, 3),
- ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
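+ // the V copy feeding the final KQV matmul also follows the compute type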
+ ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head)
);
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
#ifdef USE_FLASH_FF
cur = ggml_flash_ff(ctxL,
- ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
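+ // the input of ggml_flash_ff is copied into a tensor of the configured type as well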
+ ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
#else
// fully connected
// separate key + value memory for each processor
{
- auto & ctx = model.ctx_mem;
+ auto & mctx = model.ctx_mem;
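+ // renamed from "ctx" to avoid shadowing the whisper_context pointer,
+ // whose wtype is used for the per-processor KV tensors below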
const auto & hparams = model.hparams;
const int n_mem = n_text_layer*n_text_ctx;
const int n_elements = n_text_state*n_mem;
- model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
- model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+ model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+ model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
}
// key/value memory for the cross-attention layer
const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem;
- model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
- model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+ model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
+ model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
}
}
}