]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
models : fix qwen3.5 beta/gate shapes (#19730)
authorGeorgi Gerganov <redacted>
Thu, 19 Feb 2026 13:19:53 +0000 (15:19 +0200)
committerGitHub <redacted>
Thu, 19 Feb 2026 13:19:53 +0000 (15:19 +0200)
* models : fix qwen3.5 beta/gate shapes

* cont : avoid extra reshapes

src/models/kimi-linear.cpp
src/models/qwen35.cpp
src/models/qwen35moe.cpp

index 8173d894ef2ef2dfda2aa2667b8ca38ac0522923..4d6bb83c14215dc6041a2947bab8e27d92ecfd96 100644 (file)
@@ -149,17 +149,19 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
             g1 = ggml_mul(ctx0, g1, A);
             cb(g1, "kda_g1", il);
 
+            g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs);
+
             // Compute beta (mixing coefficient)
             ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur);
-            beta = ggml_reshape_4d(ctx0, beta, n_head, 1, n_seq_tokens, n_seqs);
+            beta = ggml_reshape_4d(ctx0, beta, 1, n_head, n_seq_tokens, n_seqs);
             cb(beta, "kda_beta", il);
 
+            beta = ggml_sigmoid(ctx0, beta);
+
             // Reshape for KDA recurrence
             // {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs}
             cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
 
-            g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs);
-
             // Get SSM state and compute KDA recurrence using ggml_kda_scan
             ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
             ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs);
@@ -169,10 +171,6 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
 
             Qcur = ggml_l2_norm(ctx0, Qcur, eps_norm);
             Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm);
-            beta = ggml_sigmoid(ctx0, beta);
-
-            beta = ggml_reshape_4d(ctx0, beta,        1, n_head, n_seq_tokens, n_seqs);
-            g1   = ggml_reshape_4d(ctx0, g1,   head_dim, n_head, n_seq_tokens, n_seqs);
 
             // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens
             std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
index 7e1749b2c818a36f424f6a46b58bd9cd16450e44..56eefd7de27ba15e79db0a96b8c9cbf159d03c38 100644 (file)
@@ -216,7 +216,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
     ggml_tensor * z         = qkvz.second;
 
     ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur);
-    beta  = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs);
+    beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
     cb(beta, "beta", il);
 
     beta = ggml_sigmoid(ctx0, beta);
@@ -232,6 +232,8 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
     ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
     cb(gate, "gate", il);
 
+    gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
+
     // Get convolution states from cache
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
index e12a5dea737689df358d653fde5774a95e8c4739..c7295e3364f5e3d4eee6db78ba63b7a7a0e9d0e7 100644 (file)
@@ -216,7 +216,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
     ggml_tensor * z         = qkvz.second;
 
     ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur);
-    beta  = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs);
+    beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
     cb(beta, "beta", il);
 
     beta = ggml_sigmoid(ctx0, beta);
@@ -232,6 +232,8 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
     ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
     cb(gate, "gate", il);
 
+    gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
+
     // Get convolution states from cache
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);