Vcur,
n_state/n_head, n_head, n_ctx),
1, 2, 0, 3),
- ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)
- );
+ ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head));
struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
#else
wstate.use_buf(ctx0, -1);
- //struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
- //struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
- struct ggml_tensor* k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
- struct ggml_tensor* v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx));
// Vcross is reshaped to 2-D (n_state, n_ctx) and transposed before it is copied
// into the cache view `v` below.
+ Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+
// k: flat view of n_state*n_ctx elements into kv_cross.k; the byte offset scales
// with the layer index il.
+ struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
// v: 2-D view of size (n_ctx, n_state) into kv_cross.v, replacing the flat 1-D view
// that this change removes.
// NOTE(review): the last two arguments look like the row stride and the start
// offset, both in bytes — confirm against the ggml_view_2d declaration.
+ struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
+ ( n_ctx)*ggml_element_size(wstate.kv_cross.v),
+ (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
// Register the copies of Kcross and Vcross into the cache views with the graph gf.
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
// Scale Kcur by (n_state/n_head)^(-0.25).
Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
// The Vcur projection (attn_v_w applied to cur, plus the repeated bias attn_v_b)
// is removed here and re-added inside the braces below, narrowing its scope.
- struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
- layer.attn_v_w,
- cur);
-
- Vcur = ggml_add(ctx0,
- ggml_repeat(ctx0,
- layer.attn_v_b,
- Vcur),
- Vcur);
-
// store key and value to memory
{
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
+ layer.attn_v_w,
+ cur);
+
+ Vcur = ggml_add(ctx0,
+ ggml_repeat(ctx0,
+ layer.attn_v_b,
+ Vcur),
+ Vcur);
+
// Vcur is reshaped to 2-D (n_state, N) and transposed before being copied into `v`.
+ Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
+
// k: flat view of N*n_state elements into kv_self.k, starting (il*n_ctx + n_past)
// groups of n_state elements into the buffer.
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
- struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
// v: 2-D view of size (N, n_state) into kv_self.v. The offset is the layer base
// (il*n_ctx*n_state elements) plus n_past elements.
// NOTE(review): the third argument looks like a row stride of n_ctx elements in
// bytes — confirm against the ggml_view_2d declaration.
+ struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state,
+ ( n_ctx)*ggml_element_size(kv_self.v),
+ (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
// Register the copies of Kcur and Vcur into the cache views with the graph gf.
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
// Softmax over the masked KQ scores.
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
// Removed: V_trans was built by copying a permuted, reshaped view of kv_self.v
// into a freshly allocated 3-D tensor.
- struct ggml_tensor * V_trans =
- ggml_cpy(ctx0,
- ggml_permute(ctx0,
- ggml_reshape_3d(ctx0,
- ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
- n_state/n_head, n_head, n_past + N),
- 1, 2, 0, 3),
- ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_state/n_head, n_head));
// Added: V is a 3-D view of size (n_past + N, n_state/n_head, n_head) taken
// directly on kv_self.v at the layer offset, with no ggml_cpy.
// NOTE(review): the two arguments before the offset look like byte strides for
// dimensions 1 and 2 — confirm against the ggml_view_3d declaration.
+ struct ggml_tensor * V =
+ ggml_view_3d(ctx0, kv_self.v,
+ n_past + N, n_state/n_head, n_head,
+ n_ctx*ggml_element_size(kv_self.v),
+ n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
+ il*n_ctx*ggml_element_size(kv_self.v)*n_state);
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
// Permute KQV with axis order (0, 2, 1, 3).
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
n_state/n_head, n_head, M);
// Removed: Vcross was a (n_state/n_head, n_head, M) reshape of a flat view on
// kv_cross.v.
- struct ggml_tensor * Vcross =
- ggml_reshape_3d(ctx0,
- ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
- n_state/n_head, n_head, M);
// NOTE(review): the change re-adds the old Vcross and V_trans construction as
// commented-out code; consider dropping these lines rather than keeping dead code.
+ //struct ggml_tensor * Vcross =
+ // ggml_reshape_3d(ctx0,
+ // ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
+ // n_state/n_head, n_head, M);
- struct ggml_tensor * V_trans =
- ggml_cpy(ctx0,
- ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
- ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
+ //struct ggml_tensor * V_trans =
+ // ggml_cpy(ctx0,
+ // ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
+ // ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
+
// Added: V is a 3-D view of size (M, n_state/n_head, n_head) taken directly on
// kv_cross.v at the layer offset, replacing the permute-and-copy into a new tensor.
// NOTE(review): the two arguments before the offset look like byte strides for
// dimensions 1 and 2 — confirm against the ggml_view_3d declaration.
+ struct ggml_tensor * V =
+ ggml_view_3d(ctx0, wstate.kv_cross.v,
+ M, n_state/n_head, n_head,
+ M*ggml_element_size(wstate.kv_cross.v),
+ M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
+ il*M*ggml_element_size(wstate.kv_cross.v)*n_state);
// ------
// Softmax over the KQ scores (this path uses KQ directly, not a masked tensor).
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
// Permute KQV with axis order (0, 2, 1, 3).
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);