const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_head_v = hparams.n_embd_head_v;
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
cb(q, "q", il);
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
n_embd_head_v, n_kv, n_head_kv,
- ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa),
- ggml_row_size(kv.v_l[il]->type, n_embd_head_k),
+ ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
+ ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
0);
cb(v, "v", il);
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
}
- cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
+ cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
} else {
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
cb(kqv_merged, "kqv_merged", il);
- cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
+ cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
cb(cur, "kqv_merged_cont", il);
}