special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab.add_to_gguf(self.gguf_writer)
+ def _set_vocab_glm(self):
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+ special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+ tokens, toktypes, tokpre = self.get_vocab_base()
+ self.gguf_writer.add_tokenizer_model("gpt2")
+ self.gguf_writer.add_tokenizer_pre(tokpre)
+ self.gguf_writer.add_token_list(tokens)
+ self.gguf_writer.add_token_types(toktypes)
+ # Special tokens
+ # Note: Using <|endoftext|> (151329) for eot causes endless generation
+ special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
+ special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
+ special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
+ special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
+ special_vocab.add_to_gguf(self.gguf_writer)
+
def _set_vocab_interns1(self):
tokens: list[str] = []
toktypes: list[int] = []
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
+ # TODO @ngxson : remove this when we support MTP for deepseek models
+ skip_mtp = True
+
def set_vocab(self):
try:
self._set_vocab_gpt2()
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
# skip Multi-Token Prediction (MTP) layers
- block_count = self.hparams["num_hidden_layers"]
- match = re.match(r"model.layers.(\d+)", name)
- if match and int(match.group(1)) >= block_count:
- return
+ if self.skip_mtp:
+ block_count = self.hparams["num_hidden_layers"]
+ match = re.match(r"model.layers.(\d+)", name)
+ if match and int(match.group(1)) >= block_count:
+ return
# process the experts separately
if name.find("mlp.experts") != -1:
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
def set_vocab(self):
- from transformers import AutoTokenizer
-
- tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
- special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
- tokens, toktypes, tokpre = self.get_vocab_base()
- self.gguf_writer.add_tokenizer_model("gpt2")
- self.gguf_writer.add_tokenizer_pre(tokpre)
- self.gguf_writer.add_token_list(tokens)
- self.gguf_writer.add_token_types(toktypes)
-
- # Special tokens
- # Note: Using <|endoftext|> (151329) for eot causes endless generation
- special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
- special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
- special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
- special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
-
- special_vocab.add_to_gguf(self.gguf_writer)
+ return self._set_vocab_glm()
def set_gguf_parameters(self):
super().set_gguf_parameters()
class Glm4MoeLiteModel(DeepseekV2Model):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
- # copied from Glm4MoeModel
def set_vocab(self):
- from transformers import AutoTokenizer
+ return self._set_vocab_glm()
- tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
- special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
- tokens, toktypes, tokpre = self.get_vocab_base()
- self.gguf_writer.add_tokenizer_model("gpt2")
- self.gguf_writer.add_tokenizer_pre(tokpre)
- self.gguf_writer.add_token_list(tokens)
- self.gguf_writer.add_token_types(toktypes)
- # Special tokens
- # Note: Using <|endoftext|> (151329) for eot causes endless generation
- special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
- special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
- special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
- special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
+@ModelBase.register("GlmMoeDsaForCausalLM")
+class GlmMoeDsaModel(DeepseekV2Model):
+ model_arch = gguf.MODEL_ARCH.GLM_DSA
+ skip_mtp = False
- special_vocab.add_to_gguf(self.gguf_writer)
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
+ self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+ def set_vocab(self):
+ return self._set_vocab_glm()
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+
+ rope_dim = self.hparams["qk_rope_head_dim"]
+ partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0)
+ self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor))
+
+ # NextN/MTP prediction layers
+ if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
+ self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)
+
+ # DSA indexer parameters
+ self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"])
+ self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"])
+ self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"])
@ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
TEMPERATURE_SCALE = "{arch}.attention.temperature_scale"
+ class Indexer:
+ HEAD_COUNT = "{arch}.attention.indexer.head_count"
+ KEY_LENGTH = "{arch}.attention.indexer.key_length"
+ TOP_K = "{arch}.attention.indexer.top_k"
+
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
CHATGLM = auto()
GLM4 = auto()
GLM4_MOE = auto()
+ GLM_DSA = auto()
BITNET = auto()
T5 = auto()
T5ENCODER = auto()
VISEXP_GATE = auto()
VISEXP_DOWN = auto()
VISEXP_UP = auto()
+ INDEXER_K_NORM = auto()
+ INDEXER_PROJ = auto()
+ INDEXER_ATTN_K = auto()
+ INDEXER_ATTN_Q_B = auto()
# vision
V_MMPROJ = auto()
V_MMPROJ_FC = auto()
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.GLM4: "glm4",
MODEL_ARCH.GLM4_MOE: "glm4moe",
+ MODEL_ARCH.GLM_DSA: "glm-dsa",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.T5ENCODER: "t5encoder",
MODEL_TENSOR.VISEXP_GATE: "blk.{bid}.vis_gate",
MODEL_TENSOR.VISEXP_DOWN: "blk.{bid}.vis_down",
MODEL_TENSOR.VISEXP_UP: "blk.{bid}.vis_up",
+ MODEL_TENSOR.INDEXER_K_NORM: "blk.{bid}.indexer.k_norm",
+ MODEL_TENSOR.INDEXER_PROJ: "blk.{bid}.indexer.proj",
+ MODEL_TENSOR.INDEXER_ATTN_K: "blk.{bid}.indexer.attn_k",
+ MODEL_TENSOR.INDEXER_ATTN_Q_B: "blk.{bid}.indexer.attn_q_b",
# vision
MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
+ MODEL_ARCH.GLM_DSA: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_A,
+ MODEL_TENSOR.ATTN_Q_B,
+ MODEL_TENSOR.ATTN_KV_A_MQA,
+ MODEL_TENSOR.ATTN_KV_B,
+ MODEL_TENSOR.ATTN_K_B,
+ MODEL_TENSOR.ATTN_V_B,
+ MODEL_TENSOR.ATTN_Q_A_NORM,
+ MODEL_TENSOR.ATTN_KV_A_NORM,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ MODEL_TENSOR.INDEXER_K_NORM,
+ MODEL_TENSOR.INDEXER_PROJ,
+ MODEL_TENSOR.INDEXER_ATTN_K,
+ MODEL_TENSOR.INDEXER_ATTN_Q_B,
+ # NextN/MTP tensors - preserved but unused
+ MODEL_TENSOR.NEXTN_EH_PROJ,
+ MODEL_TENSOR.NEXTN_EMBED_TOKENS,
+ MODEL_TENSOR.NEXTN_ENORM,
+ MODEL_TENSOR.NEXTN_HNORM,
+ MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
+ MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
+ ],
MODEL_ARCH.BITNET: [
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
def add_value_length_mla(self, length: int) -> None:
self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)
+ def add_indexer_head_count(self, count: int) -> None:
+ self.add_uint32(Keys.Attention.Indexer.HEAD_COUNT.format(arch=self.arch), count)
+
+ def add_indexer_key_length(self, length: int) -> None:
+ self.add_uint32(Keys.Attention.Indexer.KEY_LENGTH.format(arch=self.arch), length)
+
+ def add_indexer_top_k(self, top_k: int) -> None:
+ self.add_uint32(Keys.Attention.Indexer.TOP_K.format(arch=self.arch), top_k)
+
def add_max_alibi_bias(self, bias: float) -> None:
self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
"model.layers.{bid}.self_attn.vision_expert_query_key_value", # cogvlm
),
+ MODEL_TENSOR.INDEXER_K_NORM: (
+ "model.layers.{bid}.self_attn.indexer.k_norm", # DSA
+ ),
+
+ MODEL_TENSOR.INDEXER_PROJ: (
+ "model.layers.{bid}.self_attn.indexer.weights_proj", # DSA
+ ),
+
+ MODEL_TENSOR.INDEXER_ATTN_K: (
+ "model.layers.{bid}.self_attn.indexer.wk", # DSA
+ ),
+
+ MODEL_TENSOR.INDEXER_ATTN_Q_B: (
+ "model.layers.{bid}.self_attn.indexer.wq_b", # DSA
+ ),
+
############################################################################
# TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
MODEL_TENSOR.ENC_OUTPUT_NORM: (
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
+ { LLM_ARCH_GLM_DSA, "glm-dsa" },
{ LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
{ LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
+ { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" },
+ { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" },
+ { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
{ LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" },
{ LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" },
{ LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" },
+ { LLM_TENSOR_INDEXER_K_NORM, "blk.%d.indexer.k_norm" },
+ { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" },
+ { LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" },
+ { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" },
};
static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
};
+ case LLM_ARCH_GLM_DSA:
+ return {
+ LLM_TENSOR_TOKEN_EMBD,
+ LLM_TENSOR_OUTPUT_NORM,
+ LLM_TENSOR_OUTPUT,
+ LLM_TENSOR_ATTN_NORM,
+ LLM_TENSOR_ATTN_Q_A_NORM,
+ LLM_TENSOR_ATTN_KV_A_NORM,
+ LLM_TENSOR_ATTN_Q,
+ LLM_TENSOR_ATTN_Q_A,
+ LLM_TENSOR_ATTN_Q_B,
+ LLM_TENSOR_ATTN_KV_A_MQA,
+ LLM_TENSOR_ATTN_KV_B,
+ LLM_TENSOR_ATTN_K_B,
+ LLM_TENSOR_ATTN_V_B,
+ LLM_TENSOR_ATTN_OUT,
+ LLM_TENSOR_FFN_NORM,
+ LLM_TENSOR_FFN_GATE,
+ LLM_TENSOR_FFN_UP,
+ LLM_TENSOR_FFN_DOWN,
+ LLM_TENSOR_FFN_GATE_INP,
+ LLM_TENSOR_FFN_GATE_EXPS,
+ LLM_TENSOR_FFN_DOWN_EXPS,
+ LLM_TENSOR_FFN_UP_EXPS,
+ LLM_TENSOR_FFN_GATE_INP_SHEXP,
+ LLM_TENSOR_FFN_GATE_SHEXP,
+ LLM_TENSOR_FFN_DOWN_SHEXP,
+ LLM_TENSOR_FFN_UP_SHEXP,
+ LLM_TENSOR_FFN_EXP_PROBS_B,
+ LLM_TENSOR_INDEXER_K_NORM,
+ LLM_TENSOR_INDEXER_PROJ,
+ LLM_TENSOR_INDEXER_ATTN_K,
+ LLM_TENSOR_INDEXER_ATTN_Q_B,
+ LLM_TENSOR_NEXTN_EH_PROJ,
+ LLM_TENSOR_NEXTN_EMBED_TOKENS,
+ LLM_TENSOR_NEXTN_ENORM,
+ LLM_TENSOR_NEXTN_HNORM,
+ LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
+ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
+ };
case LLM_ARCH_BITNET:
return {
LLM_TENSOR_TOKEN_EMBD,
{LLM_TENSOR_VISEXP_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_VISEXP_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_VISEXP_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_INDEXER_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+ {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
// These tensors only exist in the last layer(s) and are treated as output tensors
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_GLM4_MOE,
+ LLM_ARCH_GLM_DSA,
LLM_ARCH_BITNET,
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
LLM_KV_ATTENTION_TEMPERATURE_SCALE,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
+ LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,
+ LLM_KV_ATTENTION_INDEXER_KEY_LENGTH,
+ LLM_KV_ATTENTION_INDEXER_TOP_K,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS,
LLM_TENSOR_VISEXP_FFN_GATE,
LLM_TENSOR_VISEXP_FFN_DOWN,
LLM_TENSOR_VISEXP_FFN_UP,
+ LLM_TENSOR_INDEXER_K_NORM,
+ LLM_TENSOR_INDEXER_PROJ,
+ LLM_TENSOR_INDEXER_ATTN_K,
+ LLM_TENSOR_INDEXER_ATTN_Q_B,
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,
std::array<float, LLAMA_MAX_LAYERS> xielu_beta;
std::array<float, LLAMA_MAX_LAYERS> xielu_eps;
+ // DSA (deepseek sparse attention)
+ uint32_t indexer_n_head = 0;
+ uint32_t indexer_head_size = 0;
+ uint32_t indexer_top_k = 0;
+
// qwen3vl deepstack
uint32_t n_deepstack_layers = 0;
case LLM_TYPE_300B_A47B: return "300B.A47B";
case LLM_TYPE_310B_A15B: return "310B.A15B";
case LLM_TYPE_355B_A32B: return "355B.A32B";
+ case LLM_TYPE_744B_A40B: return "744B.A40B";
case LLM_TYPE_E2B: return "E2B";
case LLM_TYPE_E4B: return "E4B";
default: return "?B";
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_GLM_DSA:
+ {
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);
+
+ // MoE parameters
+ ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert);
+ ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used);
+ ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
+ ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false);
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
+
+ // deepseek MLA parameters
+ ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
+ ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
+ ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false);
+ ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false);
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+ ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
+
+ // DSA parameters
+ ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head);
+ ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size);
+ ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k);
+
+ // Expert gating function (GLM-4.5 uses sigmoid)
+ ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
+ if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
+ hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
+ }
+
+ // NextN/MTP parameters
+ ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
+
+ // TODO: when MTP is implemented, this should probably be updated if needed
+ hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
+
+ switch (hparams.n_layer) {
+ case 79: type = LLM_TYPE_744B_A40B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
case LLM_ARCH_BITNET:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
}
}
break;
+ case LLM_ARCH_GLM_DSA:
+ {
+ const bool is_mla = hparams.is_mla();
+ if (!is_mla) {
+ throw std::runtime_error("GLM_DSA architecture requires MLA");
+ }
+
+ // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
+ const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
+ const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
+
+ const int64_t n_embd_head_qk_rope = hparams.n_rot;
+ const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;
+
+ const int64_t q_lora_rank = hparams.n_lora_q;
+ const int64_t kv_lora_rank = hparams.n_lora_kv;
+
+ const int64_t n_ff_exp = hparams.n_ff_exp;
+ const int64_t n_expert_shared = hparams.n_expert_shared;
+
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ // try to load output.weight, if not found, use token_embd (tied embeddings)
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+ if (!output) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ int flags = 0;
+ if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+ // skip all tensors in the NextN layers
+ // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later
+ flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED;
+ }
+
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
+ layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags);
+ layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags);
+
+ layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags);
+ layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags);
+
+ layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags);
+
+ // note: only old legacy GGUF files will have the unsplit wkv_b tensor in
+ layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags);
+ layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags);
+
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags);
+
+ // DSA indexer
+ layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags);
+ layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags);
+ layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags);
+ layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags);
+ layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags);
+ if (i < (int) hparams.n_layer_dense_lead) {
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags);
+ } else {
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags);
+ layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+
+ if (n_expert == 0) {
+ throw std::runtime_error("n_expert must be > 0");
+ }
+ if (n_expert_used == 0) {
+ throw std::runtime_error("n_expert_used must be > 0");
+ }
+
+ // MoE branch
+ layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags);
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
+
+ // Shared expert branch
+ layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags);
+ layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags);
+ layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags);
+ }
+
+ // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
+ if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+ layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
+ layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
+ layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
+
+ // Optional tensors
+ layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED);
+ layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED);
+ layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED);
+ }
+ }
+ } break;
case LLM_ARCH_NEMOTRON:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
}
- if (arch == LLM_ARCH_DEEPSEEK2) {
+ if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA) {
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
llm = std::make_unique<llm_build_deepseek>(*this, params);
} break;
case LLM_ARCH_DEEPSEEK2:
+ case LLM_ARCH_GLM_DSA:
{
llm = std::make_unique<llm_build_deepseek2>(*this, params);
} break;
case LLM_ARCH_MISTRAL3:
case LLM_ARCH_LLAMA_EMBED:
case LLM_ARCH_MAINCODER:
+ case LLM_ARCH_GLM_DSA:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2
LLM_TYPE_300B_A47B, // Ernie MoE big
LLM_TYPE_310B_A15B, // /MiMo-V2-Flash
LLM_TYPE_355B_A32B, // GLM-4.5
+ LLM_TYPE_744B_A40B, // GLM-5
LLM_TYPE_E2B,
LLM_TYPE_E4B,
};
struct ggml_tensor * ssm_g_b = nullptr;
struct ggml_tensor * ssm_o_norm = nullptr;
+ // DSA (deepseek sparse attention)
+ struct ggml_tensor * indexer_k_norm = nullptr;
+ struct ggml_tensor * indexer_k_norm_b = nullptr;
+ struct ggml_tensor * indexer_proj = nullptr;
+ struct ggml_tensor * indexer_attn_k = nullptr;
+ struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias
+
struct llama_layer_posnet posnet;
struct llama_layer_convnext convnext;
ggml_tensor * inp_out_ids = build_inp_out_ids();
- for (int il = 0; il < n_layer; ++il) {
+ int effective_n_layers = hparams.n_layer - hparams.nextn_predict_layers;
+ for (int il = 0; il < effective_n_layers; ++il) {
ggml_tensor * inpSA = inpL;
// norm
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
}
}
- if (il == n_layer - 1 && inp_out_ids) {
+ if (il == effective_n_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}