]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
models : fix assert in mamba2 graph (#20270)
authorGeorgi Gerganov <redacted>
Mon, 9 Mar 2026 11:15:15 +0000 (13:15 +0200)
committerGitHub <redacted>
Mon, 9 Mar 2026 11:15:15 +0000 (13:15 +0200)
src/models/mamba-base.cpp

index 8aedbef84e75a1ef917ac3eb86c39ca05c160613..8a79fe4b6cdbb1442e0ab4788a13b0a4886c999c 100644 (file)
@@ -155,7 +155,6 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
 
     const auto kv_head = mctx_cur->get_head();
 
-    const int64_t n_embd   = hparams.n_embd;
     const int64_t d_conv   = hparams.ssm_d_conv;
     const int64_t d_inner  = hparams.ssm_d_inner;
     const int64_t d_state  = hparams.ssm_d_state;
@@ -170,7 +169,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
     GGML_ASSERT(ubatch.equal_seqs());
     GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
     GGML_ASSERT(d_inner % n_head == 0);
-    GGML_ASSERT(d_inner % (n_group*n_embd) == 0);
+    GGML_ASSERT(d_inner % (n_group*d_state) == 0);
 
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);