]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
[Model] Qwen3.5 dense and MoE support (no vision) (#19435)
authorPiotr Wilkin (ilintar) <redacted>
Sun, 8 Feb 2026 23:24:08 +0000 (00:24 +0100)
committerGitHub <redacted>
Sun, 8 Feb 2026 23:24:08 +0000 (00:24 +0100)
* Unified delta net handling

* Remove old methods.

* Refactor and optimize

* Adapt autoregressive version from @ymcki

* Change to decay mask approach

* Fix bad permute

* Qwen 3.5 support

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <redacted>
* Further fixes

* Use inheritance, remove unneeded conts

* Not like this!

* Remove ggml.h explicit import

* Remove transformers, fix the views

* ACTUALLY fix views, make super calls explicit in conversion.

* Fix conversion again

* Remove extra ggml.h imports

---------

Co-authored-by: Sigbjørn Skjæret <redacted>
14 files changed:
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
src/CMakeLists.txt
src/llama-arch.cpp
src/llama-arch.h
src/llama-context.cpp
src/llama-model.cpp
src/models/delta.cpp [new file with mode: 0644]
src/models/kimi-linear.cpp
src/models/models.h
src/models/qwen3-5.cpp [new file with mode: 0644]
src/models/qwen3-5moe.cpp [new file with mode: 0644]
src/models/qwen3next.cpp

index 843c00a8969fabf3dfe1bb2617c7921c51bf00aa..e64756a74ab4627b730b3b68e093f1a3548cd469 100755 (executable)
@@ -4102,39 +4102,27 @@ class Qwen2MoeModel(TextModel):
         # process the experts separately
         name = name.replace("language_model.", "") # InternVL
 
-        # handle aggregated expert tensors
-        # GGUF stores dimensions reversed from PyTorch, so:
-        # PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
-        # Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp)
-        # Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
+        # handle pre-packed expert tensors (e.g. Qwen3.5 MoE, Qwen3Next)
+        # HF stores these using nn.Linear convention: [n_expert, out_features, in_features]
+        # This matches the individual expert stacking path below (which stacks
+        # per-expert [out, in] weights into [n_expert, out, in]), so no permute is needed.
         if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
             mapped = f"{name}.weight" if not name.endswith(".weight") else name
-            # Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
-            # Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
-            # Need PyTorch: (128, 2048, 768) [reversed of GGML]
-            # So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
-            permuted = data_torch.permute(0, 2, 1).contiguous()
-            yield from super().modify_tensors(permuted, mapped, bid)
+            # HF: [n_expert, n_embd, n_ff] → GGML: {n_ff, n_embd, n_expert} ✓
+            yield from super().modify_tensors(data_torch, mapped, bid)
             return
 
         if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
-            if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
-                raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
-            split_dim = data_torch.shape[-1] // 2
-            gate = data_torch[..., :split_dim].contiguous()
-            up = data_torch[..., split_dim:].contiguous()
-            # Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
-            # Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
-            # Need PyTorch: (128, 768, 2048) [reversed of GGML]
-            # So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
-            base_name = name.removesuffix(".weight")
-            base = base_name.rsplit('.', 1)[0]
-            mapped_gate = f"{base}.gate_proj.weight"
-            mapped_up = f"{base}.up_proj.weight"
-            perm_gate = gate.permute(0, 2, 1).contiguous()
-            perm_up = up.permute(0, 2, 1).contiguous()
-            yield from super().modify_tensors(perm_gate, mapped_gate, bid)
-            yield from super().modify_tensors(perm_up, mapped_up, bid)
+            # HF: [n_expert, 2*n_ff, n_embd] → split on dim=1
+            n_ff = data_torch.shape[1] // 2
+            gate = data_torch[:, :n_ff, :].contiguous()
+            up = data_torch[:, n_ff:, :].contiguous()
+            # gate/up: [n_expert, n_ff, n_embd] → GGML: {n_embd, n_ff, n_expert} ✓
+            base_name = name.removesuffix(".weight").removesuffix(".gate_up_proj")
+            mapped_gate = f"{base_name}.gate_proj.weight"
+            mapped_up = f"{base_name}.up_proj.weight"
+            yield from super().modify_tensors(gate, mapped_gate, bid)
+            yield from super().modify_tensors(up, mapped_up, bid)
             return
 
         if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
@@ -4344,6 +4332,40 @@ class Qwen3NextModel(Qwen2MoeModel):
             yield from super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("Qwen3_5ForCausalLM", "Qwen3_5TextForCausalLM")
+class Qwen3_5Model(Qwen3NextModel):
+    model_arch = gguf.MODEL_ARCH.QWEN3_5
+
+    # Stores whichever of in_proj_a/in_proj_b is seen first, keyed by layer
+    _pending_ba: dict[int | None, tuple[str, Tensor]] = {}
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Handle split in_proj_b + in_proj_a → concatenated SSM_BETA_ALPHA
+        # safetensors sorts alphabetically so in_proj_a arrives before in_proj_b
+        if "in_proj_a.weight" in name or "in_proj_b.weight" in name:
+            which = "a" if "in_proj_a" in name else "b"
+            if bid not in self._pending_ba:
+                self._pending_ba[bid] = (which, data_torch)
+                return
+            prev_which, prev_tensor = self._pending_ba.pop(bid)
+            assert prev_which != which, f"duplicate in_proj_{which} for layer {bid}"
+            b_tensor = prev_tensor if prev_which == "b" else data_torch
+            a_tensor = prev_tensor if prev_which == "a" else data_torch
+            ba_combined = torch.cat([b_tensor, a_tensor], dim=0)
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.SSM_BETA_ALPHA, bid, ".weight"), ba_combined)
+            return
+        else:
+            # Qwen3Next uses .qkvz tensor, so we use the super to get the other functionalities
+            # (norm correction, A_log to A etc.) for free
+            # Qwen2Moe already does the gate_up conversion properly, just use that
+            yield from super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Qwen3_5MoeForCausalLM", "Qwen3_5MoeTextForCausalLM")
+class Qwen3_5MoeModel(Qwen3_5Model):
+    model_arch = gguf.MODEL_ARCH.QWEN3_5_MOE
+
+
 @ModelBase.register("RND1")
 class RND1Model(Qwen2MoeModel):
     model_arch = gguf.MODEL_ARCH.RND1
index 3af4fffe95742c761d5397cc29c7af63a8476408..8a3fab1e1c30b886c0f156e20e7e421b9222f0a4 100644 (file)
@@ -382,6 +382,8 @@ class MODEL_ARCH(IntEnum):
     QWEN3            = auto()
     QWEN3MOE         = auto()
     QWEN3NEXT        = auto()
+    QWEN3_5          = auto()
+    QWEN3_5_MOE      = auto()
     QWEN3VL          = auto()
     QWEN3VLMOE       = auto()
     PHI2             = auto()
@@ -812,6 +814,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.QWEN3:            "qwen3",
     MODEL_ARCH.QWEN3MOE:         "qwen3moe",
     MODEL_ARCH.QWEN3NEXT:        "qwen3next",
+    MODEL_ARCH.QWEN3_5:          "qwen3_5",
+    MODEL_ARCH.QWEN3_5_MOE:      "qwen3_5moe",
     MODEL_ARCH.QWEN3VL:          "qwen3vl",
     MODEL_ARCH.QWEN3VLMOE:       "qwen3vlmoe",
     MODEL_ARCH.PHI2:             "phi2",
@@ -1784,6 +1788,61 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.SSM_BETA_ALPHA,
         MODEL_TENSOR.SSM_OUT
     ],
+    MODEL_ARCH.QWEN3_5: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.ATTN_GATE,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_BETA_ALPHA,
+        MODEL_TENSOR.SSM_OUT,
+    ],
+    MODEL_ARCH.QWEN3_5_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.ATTN_GATE,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_INP_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_BETA_ALPHA,
+        MODEL_TENSOR.SSM_OUT,
+    ],
     MODEL_ARCH.QWEN3VL: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index 167ade780334cdfee3ebc390e5f99f6411e6d44a..43f32c7b5223360fc0c2a9d289a8a4e80f11e929 100644 (file)
@@ -228,6 +228,7 @@ class TensorNameMap:
             "transformer_encoder.{bid}.qkv",                                       # neobert
             "layers.{bid}.attn.Wqkv",                                              # modern-bert
             "model.layers.{bid}.self_attn.language_expert_query_key_value",        # cogvlm
+            "model.layers.{bid}.linear_attn.in_proj_qkv",                          # qwen3.5
         ),
 
         # Attention query
@@ -358,8 +359,9 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.ATTN_GATE: (
-            "model.layers.{bid}.self_attn.gate_proj", # afmoe
-            "model.layers.{bid}.self_attn.g_proj",    # step3.5 head-wise attention gate
+            "model.layers.{bid}.self_attn.gate_proj",   # afmoe
+            "model.layers.{bid}.self_attn.g_proj",      # step3.5 head-wise attention gate
+            "model.layers.{bid}.linear_attn.in_proj_z", # qwen3.5
         ),
 
         # Feed-forward norm
index 2115fc4255f04acea32cbc692133b69b5d272bd4..0c164617a12e6f5b3d2baf0e97b6666dba3fa612 100644 (file)
@@ -57,6 +57,7 @@ add_library(llama
             models/deci.cpp
             models/deepseek.cpp
             models/deepseek2.cpp
+            models/delta.cpp
             models/dots1.cpp
             models/dream.cpp
             models/ernie4-5-moe.cpp
@@ -122,6 +123,8 @@ add_library(llama
             models/qwen3vl-moe.cpp
             models/qwen3moe.cpp
             models/qwen3next.cpp
+            models/qwen3-5.cpp
+            models/qwen3-5moe.cpp
             models/refact.cpp
             models/rnd1.cpp
             models/rwkv6-base.cpp
index bd78f1e5562d065d4017721c3740331308ac9cdf..fce46772d7ede14a649e820ac353bdfac9df931c 100644 (file)
@@ -35,6 +35,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_QWEN3,            "qwen3"            },
     { LLM_ARCH_QWEN3MOE,         "qwen3moe"         },
     { LLM_ARCH_QWEN3NEXT,        "qwen3next"        },
+    { LLM_ARCH_QWEN3_5,          "qwen3_5"          },
+    { LLM_ARCH_QWEN3_5_MOE,      "qwen3_5moe"       },
     { LLM_ARCH_QWEN3VL,          "qwen3vl"          },
     { LLM_ARCH_QWEN3VLMOE,       "qwen3vlmoe"       },
     { LLM_ARCH_PHI2,             "phi2"             },
@@ -985,6 +987,63 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                 LLM_TENSOR_SSM_NORM,
                 LLM_TENSOR_SSM_OUT,
             };
+        case LLM_ARCH_QWEN3_5:
+            return {
+                LLM_TENSOR_TOKEN_EMBD,
+                LLM_TENSOR_OUTPUT_NORM,
+                LLM_TENSOR_OUTPUT,
+                LLM_TENSOR_ATTN_NORM,
+                LLM_TENSOR_ATTN_POST_NORM,
+                LLM_TENSOR_ATTN_Q,
+                LLM_TENSOR_ATTN_Q_NORM,
+                LLM_TENSOR_ATTN_K,
+                LLM_TENSOR_ATTN_K_NORM,
+                LLM_TENSOR_ATTN_V,
+                LLM_TENSOR_ATTN_OUT,
+                LLM_TENSOR_ATTN_QKV,
+                LLM_TENSOR_ATTN_GATE,
+                LLM_TENSOR_FFN_GATE,
+                LLM_TENSOR_FFN_DOWN,
+                LLM_TENSOR_FFN_UP,
+                LLM_TENSOR_SSM_A_NOSCAN,
+                LLM_TENSOR_SSM_CONV1D,
+                LLM_TENSOR_SSM_DT,
+                LLM_TENSOR_SSM_BETA_ALPHA,
+                LLM_TENSOR_SSM_IN,
+                LLM_TENSOR_SSM_NORM,
+                LLM_TENSOR_SSM_OUT,
+            };
+        case LLM_ARCH_QWEN3_5_MOE:
+            return {
+                LLM_TENSOR_TOKEN_EMBD,
+                LLM_TENSOR_OUTPUT_NORM,
+                LLM_TENSOR_OUTPUT,
+                LLM_TENSOR_ATTN_NORM,
+                LLM_TENSOR_ATTN_POST_NORM,
+                LLM_TENSOR_ATTN_Q,
+                LLM_TENSOR_ATTN_Q_NORM,
+                LLM_TENSOR_ATTN_K,
+                LLM_TENSOR_ATTN_K_NORM,
+                LLM_TENSOR_ATTN_V,
+                LLM_TENSOR_ATTN_OUT,
+                LLM_TENSOR_ATTN_QKV,
+                LLM_TENSOR_ATTN_GATE,
+                LLM_TENSOR_FFN_GATE_INP,
+                LLM_TENSOR_FFN_GATE_EXPS,
+                LLM_TENSOR_FFN_DOWN_EXPS,
+                LLM_TENSOR_FFN_UP_EXPS,
+                LLM_TENSOR_FFN_GATE_INP_SHEXP,
+                LLM_TENSOR_FFN_GATE_SHEXP,
+                LLM_TENSOR_FFN_DOWN_SHEXP,
+                LLM_TENSOR_FFN_UP_SHEXP,
+                LLM_TENSOR_SSM_A_NOSCAN,
+                LLM_TENSOR_SSM_CONV1D,
+                LLM_TENSOR_SSM_DT,
+                LLM_TENSOR_SSM_BETA_ALPHA,
+                LLM_TENSOR_SSM_IN,
+                LLM_TENSOR_SSM_NORM,
+                LLM_TENSOR_SSM_OUT,
+            };
         case LLM_ARCH_QWEN3VL:
         case LLM_ARCH_CHAMELEON:
         case LLM_ARCH_HUNYUAN_DENSE:
@@ -2674,6 +2733,8 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
         case LLM_ARCH_NEMOTRON_H:
         case LLM_ARCH_NEMOTRON_H_MOE:
         case LLM_ARCH_QWEN3NEXT:
+        case LLM_ARCH_QWEN3_5:
+        case LLM_ARCH_QWEN3_5_MOE:
         case LLM_ARCH_KIMI_LINEAR:
             return true;
         default:
index e8263369b806ac75608e36f33b38f45f953c6a67..a392ecce2b4c3c77cbdd312cc7f8e08b2d0dd3d5 100644 (file)
@@ -39,6 +39,8 @@ enum llm_arch {
     LLM_ARCH_QWEN3,
     LLM_ARCH_QWEN3MOE,
     LLM_ARCH_QWEN3NEXT,
+    LLM_ARCH_QWEN3_5,
+    LLM_ARCH_QWEN3_5_MOE,
     LLM_ARCH_QWEN3VL,
     LLM_ARCH_QWEN3VLMOE,
     LLM_ARCH_PHI2,
index a6df893a311d382656600f67b3c9bafd5497f1d9..80b9a7d46a6afbed7cc35efcc0d264585fca73c1 100644 (file)
@@ -2013,7 +2013,7 @@ void llama_context::output_reorder() {
 //
 
 uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
-    if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR) {
+    if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_QWEN3_5 || model.arch == LLM_ARCH_QWEN3_5_MOE || model.arch == LLM_ARCH_KIMI_LINEAR) {
         return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
     }
     uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
index 674d06c8910d81ae57cc886e750ffeec8597f52f..8fc61aee3727552928110b0f6abe928fb4855a64 100644 (file)
@@ -2412,6 +2412,25 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_QWEN3_5:
+        case LLM_ARCH_QWEN3_5_MOE:
+            {
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp, false);
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,       hparams.f_norm_rms_eps);
+
+                // Load linear attention (gated delta net) parameters
+                ml.get_key(LLM_KV_SSM_CONV_KERNEL,    hparams.ssm_d_conv);
+                ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner);
+                ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state);
+                ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+                ml.get_key(LLM_KV_SSM_GROUP_COUNT,    hparams.ssm_n_group);
+
+                // Mark recurrent layers (linear attention layers)
+                for (uint32_t i = 0; i < hparams.n_layer; ++i) {
+                    hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0);
+                }
+            } break;
         case LLM_ARCH_MISTRAL3:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -7094,6 +7113,129 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
                         layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
 
+                        // Shared experts
+                        layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
+                        layer.ffn_gate_shexp     = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP,     "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
+                        layer.ffn_up_shexp       = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,       "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
+                        layer.ffn_down_shexp     = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP,     "weight", i), { hparams.n_ff_shexp, n_embd }, 0);
+                    }
+                } break;
+            case LLM_ARCH_QWEN3_5:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
+                    }
+
+                    // Calculate dimensions from hyperparameters
+                    const int64_t head_k_dim = hparams.ssm_d_state;
+                    const int64_t head_v_dim = hparams.ssm_d_state;
+                    const int64_t n_k_heads  = hparams.ssm_n_group;
+                    const int64_t n_v_heads  = hparams.ssm_dt_rank;
+                    const int64_t key_dim    = head_k_dim * n_k_heads;
+                    const int64_t value_dim  = head_v_dim * n_v_heads;
+                    const int64_t conv_dim   = key_dim * 2 + value_dim;
+
+                    const int64_t ba_dim = n_v_heads * 2;
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm      = create_tensor(tn(LLM_TENSOR_ATTN_NORM,      "weight", i), { n_embd }, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
+
+                        if (!hparams.is_recurrent(i)) {
+                            // Full attention layers
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), { n_embd, n_embd_k_gqa }, 0);
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), { n_embd, n_embd_v_gqa }, 0);
+                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+
+                            layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
+                            layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
+                        } else {
+                            // Linear attention (gated delta net) specific tensors
+                            layer.ssm_in         = create_tensor(tn(LLM_TENSOR_SSM_IN,         "weight", i), { n_embd, key_dim * 2 + value_dim * 2 }, TENSOR_NOT_REQUIRED);
+                            layer.wqkv           = create_tensor(tn(LLM_TENSOR_ATTN_QKV,       "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
+                            layer.wqkv_gate      = create_tensor(tn(LLM_TENSOR_ATTN_GATE,      "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
+                            layer.ssm_conv1d     = create_tensor(tn(LLM_TENSOR_SSM_CONV1D,     "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
+                            layer.ssm_dt         = create_tensor(tn(LLM_TENSOR_SSM_DT,         "bias",   i), { hparams.ssm_dt_rank }, 0);
+                            layer.ssm_a          = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN,             i), { hparams.ssm_dt_rank }, 0);
+                            layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0);
+                            layer.ssm_norm       = create_tensor(tn(LLM_TENSOR_SSM_NORM,       "weight", i), { head_v_dim }, 0);
+                            layer.ssm_out        = create_tensor(tn(LLM_TENSOR_SSM_OUT,        "weight", i), { value_dim, n_embd }, 0);
+                        }
+
+                        // Dense FFN for all layers
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), { n_embd, n_ff }, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
+                    }
+                } break;
+            case LLM_ARCH_QWEN3_5_MOE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
+                    }
+
+                    const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
+
+                    // Calculate dimensions from hyperparameters
+                    const int64_t head_k_dim = hparams.ssm_d_state;
+                    const int64_t head_v_dim = hparams.ssm_d_state;
+                    const int64_t n_k_heads  = hparams.ssm_n_group;
+                    const int64_t n_v_heads  = hparams.ssm_dt_rank;
+                    const int64_t key_dim    = head_k_dim * n_k_heads;
+                    const int64_t value_dim  = head_v_dim * n_v_heads;
+                    const int64_t conv_dim   = key_dim * 2 + value_dim;
+
+                    const int64_t ba_dim = n_v_heads * 2;
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm      = create_tensor(tn(LLM_TENSOR_ATTN_NORM,      "weight", i), { n_embd }, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
+
+                        if (!hparams.is_recurrent(i)) {
+                            // Full attention layers
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), { n_embd, n_embd_k_gqa }, 0);
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), { n_embd, n_embd_v_gqa }, 0);
+                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+
+                            layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
+                            layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
+                        } else {
+                            // Linear attention (gated delta net) specific tensors
+                            layer.ssm_in         = create_tensor(tn(LLM_TENSOR_SSM_IN,         "weight", i), { n_embd, key_dim * 2 + value_dim * 2 }, TENSOR_NOT_REQUIRED);
+                            layer.wqkv           = create_tensor(tn(LLM_TENSOR_ATTN_QKV,       "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
+                            layer.wqkv_gate      = create_tensor(tn(LLM_TENSOR_ATTN_GATE,      "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
+                            layer.ssm_conv1d     = create_tensor(tn(LLM_TENSOR_SSM_CONV1D,     "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
+                            layer.ssm_dt         = create_tensor(tn(LLM_TENSOR_SSM_DT,         "bias",   i), { hparams.ssm_dt_rank }, 0);
+                            layer.ssm_a          = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN,             i), { hparams.ssm_dt_rank }, 0);
+                            layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0);
+                            layer.ssm_norm       = create_tensor(tn(LLM_TENSOR_SSM_NORM,       "weight", i), { head_v_dim }, 0);
+                            layer.ssm_out        = create_tensor(tn(LLM_TENSOR_SSM_OUT,        "weight", i), { value_dim, n_embd }, 0);
+                        }
+
+                        // MoE FFN
+                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), { n_embd, n_expert }, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+
                         // Shared experts
                         layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
                         layer.ffn_gate_shexp     = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP,     "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
@@ -7545,6 +7687,8 @@ void llama_model::print_info() const {
         arch == LLM_ARCH_PLAMO2 ||
         arch == LLM_ARCH_GRANITE_HYBRID ||
         arch == LLM_ARCH_QWEN3NEXT ||
+        arch == LLM_ARCH_QWEN3_5 ||
+        arch == LLM_ARCH_QWEN3_5_MOE ||
         arch == LLM_ARCH_NEMOTRON_H ||
         arch == LLM_ARCH_NEMOTRON_H_MOE) {
         LLAMA_LOG_INFO("%s: ssm_d_conv            = %u\n",     __func__, hparams.ssm_d_conv);
@@ -8343,6 +8487,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_qwen3next>(*this, params);
             } break;
+        case LLM_ARCH_QWEN3_5:
+            {
+                llm = std::make_unique<llm_build_qwen3_5>(*this, params);
+            } break;
+        case LLM_ARCH_QWEN3_5_MOE:
+            {
+                llm = std::make_unique<llm_build_qwen3_5_moe>(*this, params);
+            } break;
         case LLM_ARCH_MISTRAL3:
             {
                 llm = std::make_unique<llm_build_mistral3>(*this, params);
@@ -8603,6 +8755,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_PANGU_EMBED:
         case LLM_ARCH_AFMOE:
         case LLM_ARCH_QWEN3NEXT:
+        case LLM_ARCH_QWEN3_5:
+        case LLM_ARCH_QWEN3_5_MOE:
         case LLM_ARCH_MIMO2:
         case LLM_ARCH_STEP35:
             return LLAMA_ROPE_TYPE_NEOX;
diff --git a/src/models/delta.cpp b/src/models/delta.cpp
new file mode 100644 (file)
index 0000000..d1d9837
--- /dev/null
@@ -0,0 +1,618 @@
+#include "models.h"
+#include "ggml.h"
+#include <cmath>
+#include <utility>
+#include <cassert>
+
+llm_graph_context_delta::llm_graph_context_delta(const llm_graph_params & params) : llm_graph_context_mamba(params) {}
+
+/**
+ * Unified Delta Net implementation supporting both GDA and KDA modes.
+ *
+ * GDA (Gated Delta Attention): g has shape [H, T, B] in GGML (PyTorch: [B, T, H])
+ *   - Per-head gating, broadcasts over K dimension
+ *
+ * KDA (Key-wise Delta Attention): g has shape [K, H, T, B] in GGML (PyTorch: [B, T, H, K])
+ *   - Per-key gating
+ *
+ * The mode is auto-detected based on g's dimensionality.
+ *
+ * Tensor dimension convention:
+ *   GGML: ne[0] is innermost (fastest varying), ne[3] is outermost
+ *   PyTorch: dim 0 is outermost, dim -1 is innermost
+ *   So GGML [A, B, C, D] corresponds to PyTorch [D, C, B, A]
+ */
+
+// Helper to get a slice along dimension 2 (n_chunks dimension)
+static ggml_tensor * get_slice_2d(ggml_context * ctx, ggml_tensor * t, int64_t chunk) {
+    return ggml_view_4d(ctx, t,
+        t->ne[0], t->ne[1], 1, t->ne[3],
+        t->nb[1], t->nb[2], t->nb[3],
+        chunk * t->nb[2]);
+}
+
+/**
+ * Unified chunked Delta Net implementation.
+ *
+ * Input tensor format matches qwen3next conventions:
+ * @param q         Query tensor [S_k, H_k, n_tokens, n_seqs]
+ * @param k         Key tensor [S_k, H_k, n_tokens, n_seqs]
+ * @param v         Value tensor [S_v, H_v, n_tokens, n_seqs]
+ * @param g         Gate tensor:
+ *                    GDA: [H_v, n_tokens, n_seqs]
+ *                    KDA: [S_k, H_v, n_tokens, n_seqs]
+ * @param beta      Beta tensor [H_v, 1, n_tokens, n_seqs]
+ * @param state     State tensor [S_v, S_v * H_v, 1, n_seqs]
+ * @param causal_mask   Lower triangular mask [chunk_size, chunk_size]
+ * @param identity      Identity matrix [chunk_size, chunk_size]
+ * @param diag_mask     Diagonal mask [chunk_size, chunk_size]
+ * @param il            Layer index (for debugging callbacks)
+ * @param chunk_size    Chunk size for chunked processing
+ * @param eps_norm      Epsilon for L2 normalization
+ *
+ * @return Pair of (output_tokens, new_state)
+ */
+std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net_unified_chunking(
+        ggml_context * ctx0,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * beta,
+        ggml_tensor * state_reshaped,
+        ggml_tensor * causal_mask,
+        ggml_tensor * identity,
+        ggml_tensor * diag_mask,
+        int           il,
+        int64_t       chunk_size,
+        float         eps_norm) {
+
+    // Input format: [S, H, n_tokens, n_seqs] (matching qwen3next convention)
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    // Detect KDA vs GDA based on g's shape
+    // GDA: g has shape [H_v, n_tokens, n_seqs]
+    // KDA: g has shape [S_k, H_v, n_tokens, n_seqs] (4D with ne[0]=S_k)
+    const bool is_kda = (g->ne[0] == S_k && g->ne[1] == H_v);
+
+    // Validate tensor shapes
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(state_reshaped->ne[0] == S_v && state_reshaped->ne[1] == S_v && state_reshaped->ne[2] == H_v && state_reshaped->ne[3] == n_seqs);
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(H_k == H_v);
+
+    if (is_kda) {
+        // KDA: g shape [S_k, H_v, n_tokens, n_seqs]
+        GGML_ASSERT(g->ne[0] == S_k && g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
+    } else {
+        // GDA: g shape [H_v, n_tokens, n_seqs]
+        GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+    }
+
+    // L2 normalize q and k
+    q = ggml_l2_norm(ctx0, q, eps_norm);
+    k = ggml_l2_norm(ctx0, k, eps_norm);
+
+    const float scale = 1.0f / sqrtf((float)S_v);
+    q = ggml_scale(ctx0, q, scale);
+
+    beta = ggml_sigmoid(ctx0, beta);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(beta, "beta_in", il);
+    cb(g, "g_in", il);
+
+    // Permute tensors to working format [S, n_tokens, H, n_seqs]
+    // Input: [S, H, n_tokens, n_seqs] -> permute(0, 2, 1, 3) -> [S, n_tokens, H, n_seqs]
+    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+    if (is_kda) {
+        g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    } else {
+        g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
+    }
+    beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
+
+    cb(q, "q_perm", il);
+    cb(k, "k_perm", il);
+    cb(v, "v_perm", il);
+    cb(beta, "beta_perm", il);
+    cb(g, "g_perm", il);
+    cb(state_reshaped, "state_in", il);
+
+    // Padding for chunk processing
+    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
+    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
+
+    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
+    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
+    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
+    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
+    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
+
+
+    cb(q, "q_pad", il);
+    cb(k, "k_pad", il);
+    cb(v, "v_pad", il);
+    cb(beta, "beta_pad", il);
+    cb(g, "g_pad", il);
+
+    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
+    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
+
+    cb(v_beta, "v_beta", il);
+    cb(k_beta, "k_beta", il);
+
+    // Reshape to chunks
+    q      = ggml_reshape_4d(ctx0, q,      S_k, chunk_size, n_chunks, H_k * n_seqs);
+    k      = ggml_reshape_4d(ctx0, k,      S_k, chunk_size, n_chunks, H_k * n_seqs);
+    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
+    v      = ggml_reshape_4d(ctx0, v,      S_v, chunk_size, n_chunks, H_v * n_seqs);
+    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
+    beta   = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
+
+    // Reshape g for chunks
+    ggml_tensor * g_cumsum;
+    ggml_tensor * g_cumsum_t;
+    if (is_kda) {
+        // KDA: g [S_k, n_tokens+pad, H_k, n_seqs] -> [S_k, chunk_size, n_chunks, H_k * n_seqs]
+        g = ggml_reshape_4d(ctx0, g, S_k, chunk_size, n_chunks, H_k * n_seqs);
+        // Cumsum along chunk_size dimension (ne[1])
+        // GGML cumsum operates on ne[0], so we need to transpose, cumsum, transpose back
+        g = ggml_cont(ctx0, ggml_transpose(ctx0, g));  // [chunk_size, S_k, n_chunks, H_k * n_seqs]
+        g_cumsum_t = ggml_cumsum(ctx0, g);
+        g_cumsum = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum_t));  // [S_k, chunk_size, n_chunks, H_k * n_seqs]
+    } else {
+        // GDA: g [n_tokens+pad, 1, H_k, n_seqs] -> [chunk_size, 1, n_chunks, H_k * n_seqs]
+        g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
+        g_cumsum = ggml_cumsum(ctx0, g);
+        g_cumsum_t = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_k * n_seqs);
+    }
+
+    cb(g_cumsum, "g_cumsum", il);
+
+    // Build attention matrix A for the WY representation solve
+    // For GDA: A[j,i] = sum_k(k[j,k] * exp(g[j] - g[i]) * k[i,k]) = (k @ k^T) * exp(g[j] - g[i])
+    // For KDA: A[j,i] = sum_k(k_beta[j,k] * exp(g[j,k] - g[i,k]) * k[i,k])
+    // KDA uses decay mask with S_k packed into batch to compute exp(g[j,k] - g[i,k]) per-key
+
+    ggml_tensor * k_decay;
+    ggml_tensor * decay_mask = nullptr;
+    ggml_tensor * g_exp_pos = nullptr;
+
+    if (is_kda) {
+        // KDA: Use decay mask with S_k in leading dimension for efficient mul_mat reduction
+        // A[j,i] = sum_k(k_beta[j,k] * exp(g[j,k] - g[i,k]) * k[i,k])
+        // By putting S_k in dim 0, mul_mat implicitly sums over it
+
+        const int64_t CHB = n_chunks * H_k * n_seqs;
+
+        // g_cumsum_t is [chunk_size, S_k, n_chunks, H_k * n_seqs]
+        // Reshape to [chunk_size, S_k, CHB] then build decay mask
+        ggml_tensor * gcs = ggml_reshape_3d(ctx0, g_cumsum_t, chunk_size, S_k, CHB);
+        ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, gcs, chunk_size, 1, S_k, CHB);
+        ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, gcs, 1, chunk_size, S_k, CHB);
+
+        // Build decay mask: [chunk_size, chunk_size, S_k, CHB]
+        ggml_tensor * gcs_j_bc = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, S_k, CHB);
+        decay_mask = ggml_sub(ctx0, gcs_j_bc, gcs_i);
+
+        cb(decay_mask, "decay_mask_kda", il);
+
+        decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+        decay_mask = ggml_exp(ctx0, decay_mask);
+        decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+
+        // Permute to [S_k, chunk_size_j, chunk_size_i, CHB] for mul_mat reduction over S_k
+        decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
+
+        // Reshape k and k_beta for broadcasting with decay_mask
+        // k_i: indexed at position i (dim 2 of decay_mask)
+        // k_beta_j: indexed at position j (dim 1 of decay_mask)
+        ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
+        ggml_tensor * k_beta_j = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, 1, CHB);
+
+        // decay_k_beta_j[s,j,i,b] = decay[s,j,i,b] * k_beta[s,j,b]
+        ggml_tensor * decay_k_beta_j = ggml_mul(ctx0, decay_mask, k_beta_j);
+
+        // mul_mat sums over S_k: result[j,1,i,CHB] = sum_s decay_k_beta_j[s,j,i,b] * k_i[s,1,i,b]
+        k_decay = ggml_mul_mat(ctx0, decay_k_beta_j, k_i);
+        k_decay = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, k_decay, chunk_size, chunk_size, n_chunks, H_k * n_seqs)));
+
+        // g_exp_pos is still needed for later (kbeta_gexp, etc.)
+        g_exp_pos = ggml_exp(ctx0, g_cumsum);
+    } else {
+        // GDA: Use decay mask approach (g broadcasts over K dimension)
+        // g_cumsum [chunk_size, 1, n_chunks, H_v * n_seqs]
+        ggml_tensor * gcs_i = g_cumsum;
+        ggml_tensor * gcs_j = g_cumsum_t;
+        g_exp_pos = ggml_exp(ctx0, g_cumsum_t);
+        ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
+        decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
+
+        cb(decay_mask, "decay_mask", il);
+
+        decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+        decay_mask = ggml_exp(ctx0, decay_mask);
+        decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+
+        ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
+        k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
+    }
+
+    ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
+
+    cb(attn, "attn_pre_solve", il);
+
+    // Solve triangular system: (I + L) @ X = I, where L is strictly lower triangular
+    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
+    ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
+    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
+    attn = ggml_mul(ctx0, lin_solve, causal_mask);
+    attn = ggml_add(ctx0, attn, identity);
+
+    cb(attn, "attn_solved", il);
+
+    // Compute u = A @ v and w = A @ (g.exp() * k)
+    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
+
+    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, g_exp_pos);
+    cb(kbeta_gexp, "kbeta_gexp", il);
+
+    ggml_tensor * k_cumdecay = ggml_cont(ctx0, ggml_transpose(ctx0,
+        ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
+    cb(k_cumdecay, "k_cumdecay", il);
+
+    // Attention scores q @ k^T with decay
+    // For GDA: attn_kq[j,i] = sum_k(q[j,k] * exp(g[j] - g[i]) * k[i,k])
+    // For KDA: attn_kq[j,i] = sum_k(q[j,k] * exp(g[j,k] - g[i,k]) * k[i,k])
+    ggml_tensor * attn_kq;
+    if (is_kda) {
+        // KDA: Same approach as k_decay - use decay_mask with S_k in leading dim
+        const int64_t CHB = n_chunks * H_k * n_seqs;
+
+        // Rebuild decay mask (same structure as k_decay)
+        ggml_tensor * gcs = ggml_reshape_3d(ctx0, g_cumsum_t, chunk_size, S_k, CHB);
+        ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, gcs, chunk_size, 1, S_k, CHB);
+        ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, gcs, 1, chunk_size, S_k, CHB);
+        ggml_tensor * gcs_j_bc = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, S_k, CHB);
+        ggml_tensor * decay_mask_kq = ggml_sub(ctx0, gcs_j_bc, gcs_i);
+
+        decay_mask_kq = ggml_mul(ctx0, decay_mask_kq, diag_mask);
+        decay_mask_kq = ggml_exp(ctx0, decay_mask_kq);
+        decay_mask_kq = ggml_mul(ctx0, decay_mask_kq, diag_mask);
+
+        // Permute to [S_k, chunk_size_j, chunk_size_i, CHB]
+        decay_mask_kq = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask_kq, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
+
+        // q_j: indexed at position j, k_i: indexed at position i
+        ggml_tensor * q_j = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);
+        ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
+
+        // decay_q_j[s,j,i,b] = decay[s,j,i,b] * q[s,j,b]
+        ggml_tensor * decay_q_j = ggml_mul(ctx0, decay_mask_kq, q_j);
+
+        // mul_mat sums over S_k
+        attn_kq = ggml_mul_mat(ctx0, decay_q_j, k_i);
+        attn_kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, attn_kq, chunk_size, chunk_size, n_chunks, H_k * n_seqs)));
+    } else {
+        // GDA: Use decay mask
+        attn_kq = ggml_mul_mat(ctx0, k, q);
+        attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
+        attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
+    }
+    cb(attn_kq, "attn_kq", il);
+
+    // Compute g_last and g_diff for state updates
+    ggml_tensor * g_last;
+    ggml_tensor * g_diff_exp;
+    ggml_tensor * g_last_exp;
+
+    if (is_kda) {
+        // KDA: g_cumsum [S_k, chunk_size, n_chunks, H_k * n_seqs]
+        // Get last element along chunk_size dimension (ne[1])
+        g_last = ggml_view_4d(ctx0, g_cumsum,
+            g_cumsum->ne[0], 1, g_cumsum->ne[2], g_cumsum->ne[3],
+            g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
+            (g_cumsum->ne[1] - 1) * g_cumsum->nb[1]);
+        g_last = ggml_cont(ctx0, g_last);
+        g_last_exp = ggml_exp(ctx0, g_last);
+
+        // g_diff = g_last - g_cumsum
+        ggml_tensor * g_last_broadcast = ggml_repeat_4d(ctx0, g_last,
+            g_cumsum->ne[0], g_cumsum->ne[1], g_cumsum->ne[2], g_cumsum->ne[3]);
+        ggml_tensor * g_diff = ggml_sub(ctx0, g_last_broadcast, g_cumsum);
+        g_diff_exp = ggml_exp(ctx0, g_diff);
+    } else {
+        // GDA: g_cumsum [chunk_size, 1, n_chunks, H_k * n_seqs]
+        g_last = ggml_view_4d(ctx0, g_cumsum,
+            1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
+            g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
+            (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
+        g_last = ggml_cont(ctx0, g_last);
+        g_last_exp = ggml_exp(ctx0, g_last);
+
+        ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
+        g_diff_exp = ggml_exp(ctx0, g_diff);
+    }
+
+    cb(g_last, "g_last", il);
+    cb(g_last_exp, "g_last_exp", il);
+
+    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp);
+    cb(key_gdiff, "key_gdiff", il);
+
+    // Process chunks
+    ggml_tensor * new_state = state_reshaped;
+    ggml_tensor * core_attn_out = nullptr;
+
+    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
+        ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk);
+        ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk);
+        ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk);
+        ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
+        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, g_exp_pos, chunk);
+
+        cb(attn_chunk, "attn_chunk", il);
+
+        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3),
+            S_v, S_v, 1, H_v * n_seqs);
+
+        // v_prime = k_cumdecay @ state
+        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
+        cb(v_prime, "v_prime_chunk", il);
+
+        // v_new = v - v_prime
+        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
+        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
+        cb(v_new, "v_new_chunk", il);
+
+        // attn_inter = (q * g.exp()) @ state
+        ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk);
+        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
+        cb(attn_inter, "attn_inter_chunk", il);
+
+        // output = attn_inter + attn @ v_new
+        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
+        cb(v_attn, "v_attn_chunk", il);
+
+        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
+        cb(core_attn_out_chunk, "core_attn_out_chunk", il);
+
+        core_attn_out = core_attn_out == nullptr
+            ? core_attn_out_chunk
+            : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
+
+        // State update: state = state * g_last_exp + key_gdiff^T @ v_new
+        ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk));
+        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff)));
+
+        ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
+
+        if (is_kda) {
+            // KDA: g_last_exp [S_k, 1, n_chunks, H_k * n_seqs]
+            // State: [S_v, S_v, H_v, n_seqs]
+            // Need to reshape g_last_exp to broadcast correctly over V dimension only
+            gexp_last_chunk = ggml_reshape_4d(ctx0, gexp_last_chunk,
+                1, gexp_last_chunk->ne[0], H_v, n_seqs);  // [1, S_k, H_v, n_seqs]
+            // Transpose to [S_k, 1, H_v, n_seqs] then broadcast
+            gexp_last_chunk = ggml_cont(ctx0, ggml_permute(ctx0, gexp_last_chunk, 1, 0, 2, 3));
+        } else {
+            // GDA: g_last_exp [1, 1, n_chunks, H_k * n_seqs]
+            // Broadcasts over both K and V dimensions
+            gexp_last_chunk = ggml_reshape_4d(ctx0, gexp_last_chunk,
+                gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs);
+        }
+
+        new_state = ggml_add(ctx0,
+            ggml_mul(ctx0, new_state, gexp_last_chunk),
+            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
+    }
+
+    // Truncate padding and permute back
+    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
+        S_v, n_tokens, H_v, n_seqs,
+        ggml_row_size(core_attn_out->type, S_v),
+        ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
+        ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
+    output_tokens = ggml_cont(ctx0, output_tokens);
+
+    cb(output_tokens, "output_tokens", il);
+
+    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
+    output_tokens = ggml_cont(ctx0, output_tokens);
+
+    return {output_tokens, new_state};
+}
+
+
+/**
+ * Unified autoregressive Delta Net implementation (single token processing).
+ *
+ * This implementation uses matrix multiplication instead of elementwise operations + summation,
+ * which is more efficient and mathematically equivalent. See inline comments for equivalences.
+ *
+ * Input tensor format matches qwen3next conventions:
+ * @param q         Query tensor [S_k, H_k, 1, n_seqs]
+ * @param k         Key tensor [S_k, H_k, 1, n_seqs]
+ * @param v         Value tensor [S_v, H_v, 1, n_seqs]
+ * @param g         Gate tensor:
+ *                    GDA: [H_v, 1, n_seqs]
+ *                    KDA: [S_k, H_v, 1, n_seqs]
+ * @param beta      Beta tensor [H_v, 1, 1, n_seqs]
+ * @param state     State tensor [S_v, S_v * H_v, 1, n_seqs]
+ * @param il        Layer index (for debugging callbacks)
+ * @param eps_norm  Epsilon for L2 normalization
+ *
+ * @return Pair of (output_tokens, new_state)
+ */
+std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net_unified_autoregressive(
+        ggml_context * ctx0,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        int           il,
+        float         eps_norm) {
+
+    // Input format: [S, H, n_tokens, n_seqs] (matching qwen3next convention)
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(n_tokens == 1);  // Autoregressive mode is for single token
+
+    // Detect KDA vs GDA based on g's shape
+    // GDA: g has shape [H_v, 1, n_seqs] or [H_v, n_tokens, n_seqs]
+    // KDA: g has shape [S_k, H_v, 1, n_seqs] or [S_k, H_v, n_tokens, n_seqs]
+    const bool is_kda = (g->ne[0] == S_k && g->ne[1] == H_v);
+
+    // Validate shapes
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(H_k == H_v);
+
+    if (is_kda) {
+        GGML_ASSERT(g->ne[0] == S_k && g->ne[1] == H_v);
+    } else {
+        GGML_ASSERT(g->ne[0] == H_v);
+    }
+
+    // L2 normalize q and k
+    q = ggml_l2_norm(ctx0, q, eps_norm);
+    k = ggml_l2_norm(ctx0, k, eps_norm);
+
+    const float scale = 1.0f / sqrtf((float)S_v);
+    q = ggml_scale(ctx0, q, scale);
+    beta = ggml_sigmoid(ctx0, beta);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(beta, "beta_in", il);
+    cb(g, "g_in", il);
+
+    // Reshape g and beta for broadcasting
+    ggml_tensor * g_t;
+    ggml_tensor * beta_t;
+
+    if (is_kda) {
+        // KDA: g [S_k, H_v, 1, n_seqs] -> [S_k, 1, H_k, n_seqs]
+        // For state multiplication, need [1, S_k, H_v, n_seqs] to broadcast over V only
+        g_t = ggml_reshape_4d(ctx0, g, S_k, 1, H_k, n_seqs);
+    } else {
+        // GDA: g [H_v, 1, n_seqs] -> [1, 1, H_k, n_seqs]
+        // For state multiplication, broadcasts over both K and V
+        g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
+    }
+
+    beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
+
+    // Apply exponential to g_t
+    g_t = ggml_exp(ctx0, g_t);
+
+    // State decay: state = state * exp(g)
+    if (is_kda) {
+        // KDA: g_t [S_k, 1, H_k, n_seqs], state [S_v, S_v, H_v, n_seqs]
+        // Need to broadcast g_t over V dimension (ne[0] of state)
+        // Permute g_t to [1, S_k, H_k, n_seqs] for correct broadcasting
+        ggml_tensor * g_broadcast = ggml_cont(ctx0, ggml_permute(ctx0, g_t, 1, 0, 2, 3));
+        state = ggml_mul(ctx0, state, g_broadcast);
+    } else {
+        // GDA: g_t [1, 1, H_k, n_seqs] broadcasts over both dimensions
+        state = ggml_mul(ctx0, state, g_t);
+    }
+
+    // Equivalence to previous version:
+    // Previous: kv_mem = sum_k(state * k) using elementwise mult + sum_rows
+    // Current:  k_state = state_t @ k_t using matrix multiplication
+    // These are equivalent because: sum_k(A * B) = A @ B when dimensions align
+    ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
+    ggml_tensor * k_t = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs);
+    ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k_t);
+
+    // v_diff = v - k_state (equivalent to v - kv_mem in previous version)
+    ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
+    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, k_state);
+    ggml_tensor * k_beta = ggml_mul(ctx0, k_t, beta_t);
+
+    // Equivalence to previous version:
+    // Previous: state += k.unsqueeze(-1) * delta where delta = (v - kv_mem) * beta
+    // Current:  state += v_diff^T @ k_beta^T using matrix multiplication
+    // These are equivalent because: outer_product(k, v_diff * beta) = v_diff^T @ k^T
+    state = ggml_add(ctx0, state, ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_diff)), ggml_cont(ctx0, ggml_transpose(ctx0, k_beta))));
+
+    // Equivalence to previous version:
+    // Previous: core_attn_out = sum_k(state * q) using elementwise mult + sum_rows
+    // Current:  core_attn_out = state_t @ q using matrix multiplication
+    // These are equivalent because: sum_k(A * B) = A @ B when dimensions align
+    q = ggml_reshape_4d(ctx0, q, S_k, 1, H_k, n_seqs);
+    state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
+    ggml_tensor * core_attn_out = ggml_mul_mat(ctx0, state_t, q);
+    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
+    cb(core_attn_out, "output_tokens", il);
+    cb(state, "new_state", il);
+
+    return {core_attn_out, state};
+}
+
+
+/**
+ * Main entry point that dispatches to chunked or autoregressive based on n_tokens.
+ *
+ * Input tensor format matches qwen3next conventions:
+ * @param q         Query tensor [S_k, H_k, n_tokens, n_seqs]
+ * @param k         Key tensor [S_k, H_k, n_tokens, n_seqs]
+ * @param v         Value tensor [S_v, H_v, n_tokens, n_seqs]
+ * @param g         Gate tensor (GDA: [H_v, n_tokens, n_seqs], KDA: [S_k, H_v, n_tokens, n_seqs])
+ * @param beta      Beta tensor [H_v, 1, n_tokens, n_seqs]
+ * @param state     State tensor [S_v, S_v * H_v, 1, n_seqs]
+ */
+std::pair<ggml_tensor *, ggml_tensor *> llm_graph_context_delta::build_delta_net_unified(
+        ggml_context * ctx0,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        ggml_tensor * causal_mask,
+        ggml_tensor * identity,
+        ggml_tensor * diag_mask,
+        int           il,
+        int64_t       chunk_size,
+        float         eps_norm) {
+
+    // Input format: [S, H, n_tokens, n_seqs] (matching qwen3next convention)
+    const int64_t n_tokens = q->ne[2];
+
+    if (n_tokens == 1) {
+        return build_delta_net_unified_autoregressive(
+            ctx0, q, k, v, g, beta, state, il, eps_norm);
+    }
+    return build_delta_net_unified_chunking(
+        ctx0, q, k, v, g, beta, state, causal_mask, identity, diag_mask,
+        il, chunk_size, eps_norm);
+}
index 0f037d1a39324a449333731a808ca5da55885455..d9ee6980751616804e840066970181ec2fe4049a 100644 (file)
@@ -1,5 +1,4 @@
 #include "models.h"
-#include "ggml.h"
 
 #define CHUNK_SIZE 64
 
index cfcbb9aaa5b13d9e81446b2dd6408d2d1d89e2b3..2a750c168ea986f6609d33819b5cbd698f441c67 100644 (file)
@@ -17,6 +17,53 @@ struct llm_graph_context_mamba : public llm_graph_context {
 
 };
 
+struct llm_graph_context_delta : public llm_graph_context_mamba {
+    llm_graph_context_delta(const llm_graph_params & params);
+
+    virtual ~llm_graph_context_delta() = default;
+
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_unified_chunking(
+        ggml_context * ctx0,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        ggml_tensor * causal_mask,
+        ggml_tensor * identity,
+        ggml_tensor * diag_mask,
+        int           il,
+        int64_t       chunk_size,
+        float         eps_norm);
+
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_unified_autoregressive(
+        ggml_context * ctx0,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        int           il,
+        float         eps_norm);
+
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_unified(
+        ggml_context * ctx0,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        ggml_tensor * causal_mask,
+        ggml_tensor * identity,
+        ggml_tensor * diag_mask,
+        int           il,
+        int64_t       chunk_size,
+        float         eps_norm);
+};
+
 // Base class for RWKV-related models
 struct llm_build_rwkv6_base : public llm_graph_context {
     const llama_model & model;
@@ -476,7 +523,7 @@ struct llm_build_qwen3vl : public llm_graph_context {
 struct llm_build_qwen3vlmoe : public llm_graph_context {
     llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params);
 };
-struct llm_build_qwen3next : public llm_graph_context_mamba {
+struct llm_build_qwen3next : public llm_graph_context_delta {
     llm_build_qwen3next(const llama_model & model, const llm_graph_params & params);
 private:
     ggml_tensor * build_layer_attn(
@@ -534,6 +581,59 @@ private:
     const llama_model & model;
 };
 
+struct llm_build_qwen3_5 : public llm_graph_context_delta {
+    llm_build_qwen3_5(const llama_model & model, const llm_graph_params & params);
+
+protected:
+    // Tag type for subclass constructors that need to call build_graph() themselves
+    // (to ensure virtual dispatch works correctly)
+    struct defer_graph_build_t {};
+
+    llm_build_qwen3_5(const llama_model & model, const llm_graph_params & params, defer_graph_build_t);
+
+    void build_graph();
+
+    virtual ggml_tensor * build_layer_ffn(
+                ggml_tensor * cur,
+                        int   il);
+
+    const llama_model & model;
+
+private:
+    ggml_tensor * build_layer_attn(
+    llm_graph_input_attn_kv * inp_attn,
+                ggml_tensor * cur,
+                ggml_tensor * inp_pos,
+                        int   il);
+
+    ggml_tensor * build_layer_attn_linear(
+         llm_graph_input_rs * inp,
+                ggml_tensor * cur,
+                ggml_tensor * causal_mask,
+                ggml_tensor * identity,
+                ggml_tensor * diag_mask,
+                        int   il);
+
+    ggml_tensor * build_norm_gated(
+                ggml_tensor * input,
+                ggml_tensor * weights,
+                ggml_tensor * gate,
+                        int   layer);
+
+    std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(
+                ggml_tensor * input,
+                        int   il);
+};
+
+struct llm_build_qwen3_5_moe : public llm_build_qwen3_5 {
+    llm_build_qwen3_5_moe(const llama_model & model, const llm_graph_params & params);
+
+protected:
+    ggml_tensor * build_layer_ffn(
+                ggml_tensor * cur,
+                        int   il) override;
+};
+
 struct llm_build_qwen : public llm_graph_context {
     llm_build_qwen(const llama_model & model, const llm_graph_params & params);
 };
diff --git a/src/models/qwen3-5.cpp b/src/models/qwen3-5.cpp
new file mode 100644 (file)
index 0000000..0947299
--- /dev/null
@@ -0,0 +1,421 @@
+#include "models.h"
+
+#define CHUNK_SIZE 64
+
+llm_build_qwen3_5::llm_build_qwen3_5(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context_delta(params), model(model) {
+    build_graph();
+}
+
+// virtual call in constructor fix
+llm_build_qwen3_5::llm_build_qwen3_5(const llama_model & model, const llm_graph_params & params, defer_graph_build_t /*tag*/) :
+    llm_graph_context_delta(params), model(model) {
+}
+
+void llm_build_qwen3_5::build_graph() {
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+    cb(inpL, "model.embed_tokens", -1);
+
+    auto * inp = build_inp_mem_hybrid();
+
+    ggml_tensor * inp_pos     = build_inp_pos();
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    ggml_tensor * causal_mask =
+        ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
+                    GGML_TRI_TYPE_LOWER);
+
+    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
+    ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);
+
+    ggml_build_forward_expand(gf, causal_mask);
+    ggml_build_forward_expand(gf, identity);
+    ggml_build_forward_expand(gf, diag_mask);
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        if (hparams.is_recurrent(il)) {
+            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
+        } else {
+            cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        cur = ggml_add(ctx0, cur, inpSA);
+        cb(cur, "attn_residual", il);
+
+        ggml_tensor * ffn_residual = cur;
+
+        ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
+        cb(attn_post_norm, "attn_post_norm", il);
+
+        cur = build_layer_ffn(attn_post_norm, il);
+        cb(cur, "ffn_out", il);
+
+        cur = ggml_add(ctx0, cur, ffn_residual);
+        cb(cur, "post_ffn", il);
+
+        inpL = cur;
+    }
+    cur = inpL;
+
+    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
+
+ggml_tensor * llm_build_qwen3_5::build_norm_gated(
+        ggml_tensor * input,
+        ggml_tensor * weights,
+        ggml_tensor * gate,
+        int           layer) {
+    ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
+    ggml_tensor * gated_silu = ggml_silu(ctx0, gate);
+
+    return ggml_mul(ctx0, normalized, gated_silu);
+}
+
+ggml_tensor * llm_build_qwen3_5::build_layer_attn(
+        llm_graph_input_attn_kv * inp,
+        ggml_tensor *             cur,
+        ggml_tensor *             inp_pos,
+        int                       il) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ]
+    cb(Qcur_full, "Qcur_full", il);
+
+    ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
+        ggml_element_size(Qcur_full) * n_embd_head * 2,
+        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0);
+    cb(Qcur, "Qcur_reshaped", il);
+
+    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Qcur, "Qcur_normed", il);
+
+    ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+    cb(Kcur, "Kcur", il);
+
+    ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+    cb(Vcur, "Vcur", il);
+
+    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Kcur, "Kcur_normed", il);
+
+    ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
+        ggml_element_size(Qcur_full) * n_embd_head * 2,
+        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
+        ggml_element_size(Qcur_full) * n_embd_head);
+    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+    cb(gate, "gate_reshaped", il);
+
+    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+    Qcur = ggml_rope_ext(
+            ctx0, Qcur, inp_pos, nullptr,
+            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow);
+
+    Kcur = ggml_rope_ext(
+            ctx0, Kcur, inp_pos, nullptr,
+            n_rot, rope_type, n_ctx_orig, freq_base,
+            freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+
+    cb(Qcur, "Qcur", il);
+    cb(Kcur, "Kcur", il);
+    cb(Vcur, "Vcur", il);
+
+    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+    cur = build_attn(inp,
+                nullptr, nullptr,
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+    cb(cur, "attn_pregate", il);
+
+    ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
+    cb(gate_sigmoid, "gate_sigmoid", il);
+
+    cur = ggml_mul(ctx0, cur, gate_sigmoid);
+    cb(cur, "attn_gated", il);
+
+    cur = build_lora_mm(model.layers[il].wo, cur);
+    cb(cur, "attn_output", il);
+
+    return cur;
+}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3_5::build_qkvz(
+                ggml_tensor * input,
+                        int   il) {
+    const int64_t d_inner      = hparams.ssm_d_inner;
+    const int64_t n_seqs       = ubatch.n_seqs;
+    const int64_t head_k_dim   = hparams.ssm_d_state;
+    const int64_t num_k_heads  = hparams.ssm_n_group;
+    const int64_t num_v_heads  = hparams.ssm_dt_rank;
+    const int64_t head_v_dim   = d_inner / num_v_heads;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    if (model.layers[il].wqkv) {
+        ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
+        qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
+        cb(qkv_mixed, "linear_attn_qkv_mixed", il);
+
+        ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
+        cb(z, "z", il);
+
+        return { qkv_mixed, z };
+
+    }
+    // legacy path for combined in_proj_qkvz
+    ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input);
+    cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
+
+    int64_t       qkvz_new_dim        = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
+    ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
+
+    int64_t split_sizes_qkvz[4] = {
+        head_k_dim,
+        head_k_dim,
+        head_v_dim * num_v_heads / num_k_heads,
+        head_v_dim * num_v_heads / num_k_heads
+    };
+
+    ggml_tensor * query =
+        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
+                    mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
+    cb(query, "q", il);
+
+    ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
+                                    mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                                    split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped));
+    cb(key, "k", il);
+
+    ggml_tensor * value =
+        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
+                    mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                    (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped));
+    cb(value, "v", il);
+
+    ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
+                                mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                                (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped));
+    z = ggml_cont(ctx0, z);
+    cb(z, "z", il);
+
+    ggml_tensor * query_flat = ggml_reshape_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+    cb(query_flat, "query_flat", il);
+
+    ggml_tensor * key_flat = ggml_reshape_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+    cb(key_flat, "key_flat", il);
+
+    ggml_tensor * value_flat = ggml_reshape_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+    cb(value_flat, "value_flat", il);
+
+    ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
+    qkv_mixed               = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
+    cb(qkv_mixed, "qkv_mixed", il);
+
+    return { qkv_mixed, z };
+}
+
+ggml_tensor * llm_build_qwen3_5::build_layer_attn_linear(
+        llm_graph_input_rs * inp,
+        ggml_tensor *        cur,
+        ggml_tensor *        causal_mask,
+        ggml_tensor *        identity,
+        ggml_tensor *        diag_mask,
+        int                  il) {
+    const auto * mctx_cur = inp->mctx;
+
+    const int64_t d_inner      = hparams.ssm_d_inner;
+    const int64_t n_seqs       = ubatch.n_seqs;
+    const int64_t head_k_dim   = hparams.ssm_d_state;
+    const int64_t num_k_heads  = hparams.ssm_n_group;
+    const int64_t num_v_heads  = hparams.ssm_dt_rank;
+    const int64_t head_v_dim   = d_inner / num_v_heads;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    const auto kv_head = mctx_cur->get_head();
+
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(ubatch.equal_seqs());
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+    auto qkvz = build_qkvz(cur, il);
+    ggml_tensor * qkv_mixed = qkvz.first;
+    ggml_tensor * z         = qkvz.second;
+
+    ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
+    cb(mixed_ba, "linear_attn_mixed_ba", il);
+
+    int64_t       ba_new_dim        = 2 * num_v_heads / num_k_heads;
+    ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);
+
+    int64_t split_sizes_ba[2] = {
+        num_v_heads / num_k_heads,
+        num_v_heads / num_k_heads
+    };
+
+    ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_seq_tokens, n_seqs,
+                                   mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
+    cb(b, "b", il);
+
+    ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_seq_tokens, n_seqs,
+                                   mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3],
+                                   split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
+    cb(a, "a", il);
+
+    ggml_tensor * beta  = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
+
+    ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
+
+    ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
+    ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
+    cb(alpha_softplus, "a_softplus", il);
+    ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);
+    cb(gate, "gate", il);
+
+    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
+
+    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+    cb(conv_states, "conv_states", il);
+
+    ggml_tensor * conv_kernel      = model.layers[il].ssm_conv1d;
+    const int64_t conv_kernel_size = conv_kernel->ne[0];
+    const int64_t conv_channels    = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
+    conv_states                    = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
+    cb(conv_states, "conv_states_reshaped", il);
+
+    qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
+    cb(qkv_mixed, "qkv_mixed_permuted", il);
+
+    ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
+    cb(conv_input, "conv_input", il);
+
+    ggml_tensor * last_conv_states =
+        ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
+                     conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
+    cb(last_conv_states, "last_conv_states", il);
+
+    ggml_tensor * state_update_target =
+        ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
+                     kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
+    cb(state_update_target, "state_update_target", il);
+
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
+    cb(conv_states_all, "conv_states_updated", il);
+
+    ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
+    cb(conv_output_proper, "conv_output_raw", il);
+
+    ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
+    cb(conv_output_silu, "conv_output_silu", il);
+
+    ggml_tensor * conv_qkv_mix = conv_output_silu;
+
+    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
+    int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
+
+    ggml_tensor * q_conv =
+        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
+    cb(q_conv, "q_conv", il);
+    ggml_tensor * k_conv =
+        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
+                     head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+    cb(k_conv, "k_conv", il);
+    ggml_tensor * v_conv =
+        ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
+                     2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+    cb(v_conv, "v_conv", il);
+
+    q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+
+    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
+    state               = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim,  num_v_heads, n_seqs);
+    cb(state, "state_predelta", il);
+
+    if (num_k_heads != num_v_heads) {
+        GGML_ASSERT(num_v_heads % num_k_heads == 0);
+        int64_t repeat_factor = num_v_heads / num_k_heads;
+
+        ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);
+        ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);
+
+        ggml_tensor * q_repeated =
+            ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1);
+        ggml_tensor * k_repeated =
+            ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1);
+
+        q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
+        k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
+    }
+
+    cb(q_conv, "q_conv_predelta", il);
+    cb(k_conv, "k_conv_predelta", il);
+    cb(v_conv, "v_conv_predelta", il);
+
+    std::pair<ggml_tensor *, ggml_tensor *> attn_out = build_delta_net_unified(ctx0, q_conv, k_conv, v_conv,
+            gate, beta, state, causal_mask, identity, diag_mask,
+            il, CHUNK_SIZE, hparams.f_norm_rms_eps);
+
+    ggml_tensor * output    = attn_out.first;
+    ggml_tensor * new_state = attn_out.second;
+    cb(output, "attn_output", il);
+    cb(new_state, "new_state", il);
+
+    ggml_build_forward_expand(gf,
+                              ggml_cpy(ctx0, new_state,
+                                       ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                                                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+
+    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+
+    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+
+    ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
+
+    ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+    cb(final_output, "final_output", il);
+
+    cur = build_lora_mm(model.layers[il].ssm_out, final_output);
+    cb(cur, "linear_attn_out", il);
+
+    cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
+    return cur;
+}
+
+ggml_tensor * llm_build_qwen3_5::build_layer_ffn(ggml_tensor * cur, const int il) {
+    // Qwen3.5 Dense always uses dense FFN
+    cur = build_ffn(cur,
+        model.layers[il].ffn_up, NULL, NULL,
+        model.layers[il].ffn_gate, NULL, NULL,
+        model.layers[il].ffn_down, NULL, NULL,
+        NULL,
+        LLM_FFN_SILU, LLM_FFN_PAR, il);
+    cb(cur, "ffn_out", il);
+    return cur;
+}
diff --git a/src/models/qwen3-5moe.cpp b/src/models/qwen3-5moe.cpp
new file mode 100644 (file)
index 0000000..a488443
--- /dev/null
@@ -0,0 +1,52 @@
+#include "models.h"
+
+llm_build_qwen3_5_moe::llm_build_qwen3_5_moe(const llama_model & model, const llm_graph_params & params) :
+    llm_build_qwen3_5(model, params, defer_graph_build_t{}) {
+    build_graph();
+}
+
+ggml_tensor * llm_build_qwen3_5_moe::build_layer_ffn(ggml_tensor * cur, const int il) {
+    // Check if this is an MoE layer
+    if (model.layers[il].ffn_gate_inp != nullptr) {
+        // MoE branch
+        ggml_tensor * moe_out =
+            build_moe_ffn(cur,
+                model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
+                model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
+                nullptr,
+                n_expert, n_expert_used, LLM_FFN_SILU,
+                true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
+        cb(moe_out, "ffn_moe_out", il);
+
+        // Add shared experts if present
+        if (model.layers[il].ffn_up_shexp != nullptr) {
+            ggml_tensor * ffn_shexp =
+                build_ffn(cur,
+                    model.layers[il].ffn_up_shexp, NULL, NULL,
+                    model.layers[il].ffn_gate_shexp, NULL, NULL,
+                    model.layers[il].ffn_down_shexp, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(ffn_shexp, "ffn_shexp", il);
+
+            // Apply shared expert gating (sigmoid)
+            ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
+            cb(shared_gate, "shared_expert_gate", il);
+
+            shared_gate = ggml_sigmoid(ctx0, shared_gate);
+            cb(shared_gate, "shared_expert_gate_sigmoid", il);
+
+            ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
+            cb(ffn_shexp, "ffn_shexp_gated", il);
+
+            cur = ggml_add(ctx0, moe_out, ffn_shexp);
+            cb(cur, "ffn_out", il);
+        } else {
+            cur = moe_out;
+        }
+    } else {
+        // Dense FFN branch (fallback)
+        cur = llm_build_qwen3_5::build_layer_ffn(cur, il);
+    }
+    return cur;
+}
index 99b1a76a485e8cd373cad6bf60fc5b1753d5a74b..0335f5ab7664a7acca910fc3471af1e795ed009c 100644 (file)
@@ -1,10 +1,9 @@
-#include "ggml.h"
 #include "models.h"
 
 #define CHUNK_SIZE 64
 
 llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params), model(model) {
+    llm_graph_context_delta(params), model(model) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
@@ -86,362 +85,6 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
     ggml_build_forward_expand(gf, cur);
 }
 
-// utility to get one slice from the third dimension
-// input dim:  [x, y, c, b]
-// output dim: [x, y, 1, b]
-static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
-    return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
-        t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
-}
-
-std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chunking(
-        ggml_tensor * q,
-        ggml_tensor * k,
-        ggml_tensor * v,
-        ggml_tensor * g,
-        ggml_tensor * beta,
-        ggml_tensor * state,
-        ggml_tensor * causal_mask,
-        ggml_tensor * identity,
-        ggml_tensor * diag_mask,
-        int           il) {
-    const int64_t S_k      = q->ne[0];
-    const int64_t H_k      = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs   = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-
-    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
-
-    const float eps_norm = hparams.f_norm_rms_eps;
-
-    q = ggml_l2_norm(ctx0, q, eps_norm);
-    k = ggml_l2_norm(ctx0, k, eps_norm);
-
-    const float scale = 1.0f / sqrtf(S_v);
-
-    q = ggml_scale(ctx0, q, scale);
-
-    beta = ggml_sigmoid(ctx0, beta);
-
-    cb(q, "q_in", il);
-    cb(k, "k_in", il);
-    cb(v, "v_in", il);
-    cb(beta, "beta_in", il);
-    cb(g, "g_in", il);
-
-    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
-
-    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
-    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
-
-    cb(q, "q_perm", il);
-    cb(k, "k_perm", il);
-    cb(v, "v_perm", il);
-    cb(beta, "beta_perm", il);
-    cb(g, "g_perm", il);
-    cb(state, "state_in", il);
-
-    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
-    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
-    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
-
-    // Do padding
-    const int64_t chunk_size = CHUNK_SIZE;
-
-    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
-    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
-
-    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
-    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
-    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
-    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
-    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
-
-    cb(q, "q_pad", il);
-    cb(k, "k_pad", il);
-    cb(v, "v_pad", il);
-    cb(beta, "beta_pad", il);
-    cb(g, "g_pad", il);
-
-    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
-    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
-
-    cb(v_beta, "v_beta", il);
-    cb(k_beta, "k_beta", il);
-
-    q      = ggml_reshape_4d(ctx0, q,      S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k      = ggml_reshape_4d(ctx0, k,      S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
-    v      = ggml_reshape_4d(ctx0, v,      S_v, chunk_size, n_chunks, H_v * n_seqs);
-    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
-
-    g    = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
-    beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
-
-    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
-    cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
-    ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
-
-    ggml_tensor * gcs_j_broadcast =
-        ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
-
-    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
-    cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-    decay_mask = ggml_exp(ctx0, decay_mask);
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-
-    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
-
-    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
-    ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
-    cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
-    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
-
-    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
-    attn                     = ggml_mul(ctx0, lin_solve, causal_mask);
-    attn                     = ggml_add(ctx0, attn, identity);
-    cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
-
-    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
-    ggml_tensor * gexp       = ggml_exp(ctx0, g_cumsum_t);
-
-    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
-    cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * k_cumdecay =
-        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
-    cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
-    attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
-    attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
-    cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-
-    // vectorized calculation of key_gdiff
-    // improved from the chunked version:
-    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
-    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
-    //   key_gdiff = key * g_diff.unsqueeze(-1)
-    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-
-    // get last element in g_cumsum along chunk_size dimension (ne0)
-    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
-    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
-                                        g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
-                                        (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
-    g_last = ggml_cont(ctx0, g_last);
-    cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
-    cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
-    cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
-    ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
-                                                 1, chunk_size, n_chunks, g_diff_exp->ne[3]);
-
-    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
-    cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
-    cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)
-
-
-    // state to be updated per chunk
-    ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
-    cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)
-
-    // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
-    ggml_tensor * core_attn_out = nullptr;
-
-    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
-        // shape: (S_k, chunk_size, 1, H_k * n_seqs)
-        ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul
-
-        // shape: (S_v, chunk_size, 1, H_v * n_seqs)
-        ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat
-
-        // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
-        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul
-
-        // shape: (chunk_size, 1, H_v * n_seqs)
-        ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat
-
-        // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
-        // replaced by precomputed attn_kq
-        ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
-        cb(attn_chunk, "attn_chunk", il);
-
-        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
-
-        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
-        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
-        cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs)
-
-        // v_new = v_i - v_prime
-        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
-        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
-        cb(v_new, "v_new_chunk", il);
-
-        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        ggml_tensor * q_g_exp    = ggml_mul(ctx0, q_chunk, gexp_chunk);
-        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
-        cb(attn_inter, "attn_inter_chunk", il);
-
-        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
-        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
-        cb(v_attn, "v_attn_chunk", il);
-
-        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
-        cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)
-
-        core_attn_out = core_attn_out == nullptr
-            ? core_attn_out_chunk
-            : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
-
-        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-        ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
-        //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
-
-        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-        ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
-        new_state = ggml_add(ctx0,
-            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
-            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
-    }
-
-    // truncate padded tokens
-    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
-            S_v, n_tokens, H_v, n_seqs,
-            ggml_row_size(core_attn_out->type, S_v),
-            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
-            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
-    output_tokens = ggml_cont(ctx0, output_tokens);
-    cb(output_tokens, "output_tokens", il);
-
-    // permute back to (S_v, H_v, n_tokens, n_seqs)
-    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
-    output_tokens = ggml_cont(ctx0, output_tokens);
-
-    return {output_tokens, new_state};
-}
-
-std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_autoregressive(
-        ggml_tensor * q,
-        ggml_tensor * k,
-        ggml_tensor * v,
-        ggml_tensor * g,
-        ggml_tensor * beta,
-        ggml_tensor * state,
-        int           il) {
-    const int64_t S_k      = q->ne[0];
-    const int64_t H_k      = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs   = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(n_tokens == 1);  // This function is optimized for single token processing
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-
-    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
-
-    const float eps_norm = hparams.f_norm_rms_eps;
-
-    q = ggml_l2_norm(ctx0, q, eps_norm);
-    k = ggml_l2_norm(ctx0, k, eps_norm);
-
-    const float scale = 1.0f / sqrtf(S_v);
-
-    q    = ggml_scale(ctx0, q, scale);
-    beta = ggml_sigmoid(ctx0, beta);
-
-    cb(q, "q_in", il);
-    cb(k, "k_in", il);
-    cb(v, "v_in", il);
-    cb(beta, "beta_in", il);
-    cb(g, "g_in", il);
-
-    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
-
-    ggml_tensor * g_t    = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
-    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
-
-    // Apply exponential to g_t
-    g_t = ggml_exp(ctx0, g_t);
-
-    // Apply the gated delta rule for the single timestep
-    // last_recurrent_state = last_recurrent_state * g_t
-    state = ggml_mul(ctx0, state, g_t);
-
-    // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
-    ggml_tensor * kv_mem         = ggml_mul(ctx0, state, k_t_unsqueezed);
-    // we need to sum over dim=-2, so we transpose, sum, then transpose again
-    kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
-
-    // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
-    ggml_tensor * v_t    = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
-    // delta = (v_t - kv_mem) * beta_t
-    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);  // both should be [S_v, 1, H_v, n_seqs]
-    ggml_tensor * delta  = ggml_mul(ctx0, v_diff, beta_t);
-
-    // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
-    ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
-    state                   = ggml_add(ctx0, state, k_t_delta);
-
-    // Compute the attention output
-    // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs);  // unsqueeze q_t
-    ggml_tensor * state_q        = ggml_mul(ctx0, state, q_t_unsqueezed);
-    // again, since it's over dim = -2, transpose, sum, transpose back
-    ggml_tensor * core_attn_out =
-        ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
-
-    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
-    cb(core_attn_out, "output_tokens", il);
-    cb(state, "new_state", il);
-
-    return {core_attn_out, state};
-}
-
 ggml_tensor * llm_build_qwen3next::build_norm_gated(
         ggml_tensor * input,
         ggml_tensor * weights,
@@ -752,7 +395,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
-    state               = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
+    state               = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim,  num_v_heads, n_seqs);
     cb(state, "state_predelta", il);
 
     // if head keys and value keys are different, repeat to force tensors into matching shapes
@@ -781,13 +424,10 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(k_conv, "k_conv_predelta", il);
     cb(v_conv, "v_conv_predelta", il);
 
-    // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
-    std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
-    if (n_seq_tokens == 1) {
-        attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
-    } else {
-        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
-    }
+    std::pair<ggml_tensor *, ggml_tensor *> attn_out = build_delta_net_unified(ctx0, q_conv, k_conv, v_conv,
+            gate, beta, state, causal_mask, identity, diag_mask,
+            il, CHUNK_SIZE, hparams.f_norm_rms_eps);
+
     ggml_tensor * output    = attn_out.first;
     ggml_tensor * new_state = attn_out.second;
     cb(output, "attn_output", il);