model : Granite MoE shared (#13269)
author Gabe Goodhart <redacted>
Tue, 13 May 2025 13:12:01 +0000 (07:12 -0600)
committer GitHub <redacted>
Tue, 13 May 2025 13:12:01 +0000 (15:12 +0200)
* feat: Add GGUF conversion for granitemoeshared

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <redacted>
* feat: hparam and arch plumbing for granitemoeshared

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <redacted>
* fix: Split MoE fused tensors for shared experts in conversion (see the sketch below)

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <redacted>
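
A rough sketch of the split (not the converter's exact code path; the tensor name and shapes are assumptions based on the diff below). The fused shared_mlp.input_linear weight stacks the gate rows first and the up rows second along dim -2:

    import torch

    def split_gate_up(merged: torch.Tensor, ffn_dim: int) -> tuple[torch.Tensor, torch.Tensor]:
        # merged is expected to have shape (..., 2 * ffn_dim, n_embd)
        assert merged.shape[-2] == 2 * ffn_dim, "merged tensor must stack gate and up rows"
        gate, up = merged.split(ffn_dim, dim=-2)  # first half -> gate, second half -> up
        return gate, up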
* feat: First WIP cut at model arch in cpp

The hparam and architecture plumbing should be correct, but the
implementation of the shared experts seems to still be broken.

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <redacted>
* fix: Cleaner (maybe more correct?) splitting for gate/up

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <redacted>
* fix: Fix the input to the shared experts

I had misread the reference implementation: the shared experts take their input
from _before_ the standard MoE layer (the same normalized hidden states that feed
the routed experts), but I was feeding the output of the MoE to the shared
experts instead. A rough sketch of the corrected flow follows below.

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <redacted>
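
A rough sketch of the corrected data flow (hypothetical helper callables, not the actual llama.cpp graph API): both the routed experts and the shared experts consume the same normalized FFN input, and their outputs are summed before the residual add:

    import torch

    def moe_shared_block(ffn_inp: torch.Tensor, ffn_norm, routed_experts, shared_experts) -> torch.Tensor:
        h = ffn_norm(ffn_inp)            # the same normalized hidden states feed both paths
        moe_out = routed_experts(h)      # standard top-k routed MoE
        shexp_out = shared_experts(h)    # shared-expert MLP (SwiGLU)
        return moe_out + shexp_out       # summed, then added back onto ffn_inp via the residual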
* fix: Avoid architecture-specific checks for Granite MoE Shared

This is a cleaner way that will allow more flexibility in architecture
strings going forward.

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <redacted>
* refactor: Split granite architectures out of llm_build_llama

This helps de-clutter the llama-family graph construction and allows
granite to diverge further (in preparation for Granite 4).

NOTE: I removed the granite scale factors from llm_build_deci because they
appear to only be there as copy-paste from llm_build_llama. The HF config
does not seem to set those values:
https://huggingface.co/Deci/DeciLM-7B/blob/main/config.json

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <redacted>
* fix: Fix compiler warning about uninitialized inp_pos

This should not have been reachable, but it triggers a warning on some compilers.

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <redacted>
* fix: Consolidate GraniteMoEShared into GraniteMoE for conversion

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <redacted>
* fix: Consolidate GraniteMoEShared into GraniteMoE on the C++ side

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart <redacted>
---------

Signed-off-by: Gabe Goodhart <redacted>
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-model.cpp

convert_hf_to_gguf.py
index a34ba2988238a5d7fdaf1f6648f3293618e47717..68b5e87992383db295d61b572f892a86e4b101ca 100755 (executable)
@@ -5746,11 +5746,20 @@ class GraniteModel(LlamaModel):
             logger.info("gguf: (granite) logits_scale = %s", logits_scale)
 
 
-@ModelBase.register("GraniteMoeForCausalLM")
+@ModelBase.register("GraniteMoeForCausalLM", "GraniteMoeSharedForCausalLM")
 class GraniteMoeModel(GraniteModel):
     """Conversion for IBM's GraniteMoeForCausalLM"""
     model_arch = gguf.MODEL_ARCH.GRANITE_MOE
 
+    def set_gguf_parameters(self):
+        """GraniteMoeShared uses GraniteMoe parameters plus the following:
+        - shared_intermediate_size
+        """
+        super().set_gguf_parameters()
+        if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
+            self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length)
+            logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length)
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         """In modeling_granitemoe, the JetMoe implementation of parallel experts
         is used. This essentially merges w1 and w3 into a single tensor with 2x
@@ -5761,12 +5770,21 @@ class GraniteMoeModel(GraniteModel):
         if name.endswith("block_sparse_moe.input_linear.weight"):
             ffn_dim = self.hparams["intermediate_size"]
             assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
-            gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
+            gate, up = data_torch.split(ffn_dim, dim=-2)
             return [
                 (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
                 (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
             ]
 
+        if name.endswith("shared_mlp.input_linear.weight"):
+            ffn_dim = self.hparams["shared_intermediate_size"]
+            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
+            gate, up = data_torch.split(ffn_dim, dim=-2)
+            return [
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
+            ]
+
         return super().modify_tensors(data_torch, name, bid)
 
 
gguf-py/gguf/constants.py
index 0e6226b900db9363a57c1ae9908b3bf0740435d3..21af0a9a2693f89dc794d98f49776ee93fa6fcc7 100644 (file)
@@ -1905,6 +1905,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
     ],
     MODEL_ARCH.CHAMELEON: [
         MODEL_TENSOR.TOKEN_EMBD,
gguf-py/gguf/tensor_mapping.py
index ecf21b2b441425a6f3bb5e681bcdc0315cb2fc50..2629b3c1ab428bdc8dc078166c9f0d1e962ec767 100644 (file)
@@ -428,6 +428,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_expert.down_proj",  # qwen2moe
             "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
             "language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
+            "model.layers.{bid}.shared_mlp.output_linear",     # granitemoe
         ),
 
         MODEL_TENSOR.ATTN_Q_NORM: (
src/llama-arch.cpp
index f2bc8ca76850278ce2f4b320300e503dec4158cc..abf436adac41665824c97c062d2a0c8349785b2d 100644 (file)
@@ -1481,6 +1481,9 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
             { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,  "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,  "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,    "blk.%d.ffn_up_shexp" },
         },
     },
     {
src/llama-model.cpp
index 3a4e72a36b0730417d8dd670b706053aa3f01b88..f652f4b861d1ffac74c39c5f2a017a7f7241768d 100644 (file)
@@ -1389,6 +1389,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     // Add additional layer/vocab/etc checks here for other model sizes
                     default: type = LLM_TYPE_UNKNOWN;
                 }
+
+                // For Granite MoE Shared
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false);
             } break;
         case LLM_ARCH_CHAMELEON:
             {
@@ -1772,6 +1775,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, TENSOR_NOT_REQUIRED);
                             layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
                             layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
+
+                            // For Granite MoE Shared
+                            if (hparams.n_ff_shexp > 0) {
+                                layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
+                                layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
+                                layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
+                            }
                         }
                     }
                 } break;
@@ -4385,10 +4395,13 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_ff_exp         = %d\n",     __func__, hparams.n_ff_exp);
     }
 
-    if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) {
+    if (arch == LLM_ARCH_MINICPM ||
+        arch == LLM_ARCH_GRANITE ||
+        arch == LLM_ARCH_GRANITE_MOE) {
         LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
         LLAMA_LOG_INFO("%s: f_residual_scale  = %f\n", __func__, hparams.f_residual_scale);
         LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
+        LLAMA_LOG_INFO("%s: n_ff_shexp        = %d\n", __func__, hparams.n_ff_shexp);
     }
 
     if (arch == LLM_ARCH_BAILINGMOE) {
@@ -4598,11 +4611,6 @@ struct llm_build_llama : public llm_graph_context {
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
 
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
 
@@ -4674,11 +4682,6 @@ struct llm_build_llama : public llm_graph_context {
                 cb(cur, "ffn_moe_out", il);
             }
 
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
@@ -4701,11 +4704,6 @@ struct llm_build_llama : public llm_graph_context {
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
-        // For Granite architecture
-        if (hparams.f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
-        }
-
         cb(cur, "result_output", -1);
         res->t_logits = cur;
 
@@ -4816,11 +4814,6 @@ struct llm_build_deci : public llm_graph_context {
                 continue;
             }
 
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             // modified to support attention-free layer of Llama-3_1-Nemotron-51B
             ggml_tensor * ffn_inp = cur;
             if (n_head > 0) {
@@ -4844,11 +4837,6 @@ struct llm_build_deci : public llm_graph_context {
                 cb(cur, "ffn_out", il);
             }
 
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
@@ -4871,11 +4859,6 @@ struct llm_build_deci : public llm_graph_context {
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
-        // For Granite architecture
-        if (hparams.f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
-        }
-
         cb(cur, "result_output", -1);
         res->t_logits = cur;
 
@@ -12214,6 +12197,195 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
     }
 };
 
+
+struct llm_build_granite : public llm_graph_context {
+    llm_build_granite(
+        const llama_model & model,
+        const llm_graph_params & params,
+        ggml_cgraph * gf,
+        const bool use_rope = true)
+        : llm_graph_context(params) {
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - built only if rope enabled
+        ggml_tensor * inp_pos = nullptr;
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and (optionally) RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                if (use_rope) {
+
+                    if (!inp_pos) {
+                        inp_pos = build_inp_pos();
+                    }
+                    ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                    Qcur = ggml_rope_ext(
+                            ctx0, Qcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
+
+                    Kcur = ggml_rope_ext(
+                            ctx0, Kcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // For Granite architectures - scale residual
+            cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network (non-MoE)
+            if (model.layers[il].ffn_gate_inp == nullptr) {
+
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+
+            } else {
+                // MoE branch
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                ggml_tensor * moe_out = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                // For Granite MoE Shared
+                if (hparams.n_ff_shexp > 0) {
+                    ggml_tensor * ffn_shexp = build_ffn(cur,
+                        model.layers[il].ffn_up_shexp,   NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                } else {
+                    cur = moe_out;
+                }
+            }
+
+            // For Granite architectures - scale residual
+            cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // For Granite architectures - scale logits
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 // ref: https://github.com/facebookresearch/chameleon
 // based on the original build_llama() function, changes:
 //   * qk-norm
@@ -12921,8 +13093,6 @@ llm_graph_result_ptr llama_model::build_graph(
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_LLAMA4:
         case LLM_ARCH_MINICPM:
-        case LLM_ARCH_GRANITE:
-        case LLM_ARCH_GRANITE_MOE:
             {
                 llm = std::make_unique<llm_build_llama>(*this, params, gf);
             } break;
@@ -13153,6 +13323,11 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_arwkv7>(*this, params, gf);
             } break;
+        case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
+            {
+                llm = std::make_unique<llm_build_granite>(*this, params, gf);
+            } break;
         case LLM_ARCH_CHAMELEON:
             {
                 llm = std::make_unique<llm_build_chameleon>(*this, params, gf);