cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
- ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp);
+ ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
+ 1, chunk_size, n_chunks, g_diff_exp->ne[3]);
+
+ ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
+ ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
+ cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)
+
// state to be updated per chunk
ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
: ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
// kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
- ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk));
+ ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
//ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
- ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff)));
+ ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));