]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model: (qwen3next) correct vectorized key_gdiff calculation (#19324)
authorXuan-Son Nguyen <redacted>
Wed, 4 Feb 2026 12:09:58 +0000 (13:09 +0100)
committerGitHub <redacted>
Wed, 4 Feb 2026 12:09:58 +0000 (13:09 +0100)
* model: (qwen3next) correct vectorized key_gdiff calculation

* move transpose to outside of loop

src/models/qwen3next.cpp

index 57b6659baf0869d8ef28a5a5f1e273b3cc3e1a05..06d946c5fa54087775a9c7d081287ed7857f4062 100644 (file)
@@ -265,9 +265,15 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
     cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
 
     ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
-    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp);
+    ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, 
+                                                 1, chunk_size, n_chunks, g_diff_exp->ne[3]);
+
+    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
     cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
 
+    ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); 
+    cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)
+
 
     // state to be updated per chunk
     ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
@@ -322,9 +328,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
             : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
 
         // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-        ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk));
+        ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
         //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff)));
+        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
 
         // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
         ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));