if chkhsh == "8b5a93ed704057481f240da0be7e7dca721d7f8f4755263b6807227a2cbeae65":
# ref: https://huggingface.co/sentence-transformers/stsb-roberta-base
res = "roberta-bpe"
+ if chkhsh == "ad851be1dba641f2e3711822f816db2c265f788b37c63b4e1aeacb9ee92de8eb":
+ # ref: https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct
+ res = "gigachat"
if res is None:
logger.warning("\n")
raise ValueError(f"Unprocessed experts: {experts}")
+@Model.register("DeepseekForCausalLM")
+class DeepseekModel(Model):
+ model_arch = gguf.MODEL_ARCH.DEEPSEEK
+
+ def set_vocab(self):
+ try:
+ self._set_vocab_sentencepiece()
+ except FileNotFoundError:
+ self._set_vocab_gpt2()
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ hparams = self.hparams
+ if "head_dim" in hparams:
+ rope_dim = hparams["head_dim"]
+ else:
+ rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+
+ self.gguf_writer.add_rope_dimension_count(rope_dim)
+ self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+ self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
+ self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+ self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+ self.gguf_writer.add_expert_weights_scale(1.0)
+ self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
+ self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
+
+ _experts: list[dict[str, Tensor]] | None = None
+
+ @staticmethod
+ def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
+ if n_head_kv is not None and n_head != n_head_kv:
+ n_head = n_head_kv
+ return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+ .swapaxes(1, 2)
+ .reshape(weights.shape))
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ n_head = self.hparams["num_attention_heads"]
+ n_kv_head = self.hparams.get("num_key_value_heads")
+
+ if name.endswith(("q_proj.weight", "q_proj.bias")):
+ data_torch = DeepseekModel.permute(data_torch, n_head, n_head)
+ if name.endswith(("k_proj.weight", "k_proj.bias")):
+ data_torch = DeepseekModel.permute(data_torch, n_head, n_kv_head)
+
+ # process the experts separately
+ if name.find("mlp.experts") != -1:
+ n_experts = self.hparams["n_routed_experts"]
+ assert bid is not None
+
+ if self._experts is None:
+ self._experts = [{} for _ in range(self.block_count)]
+
+ self._experts[bid][name] = data_torch
+
+ if len(self._experts[bid]) >= n_experts * 3:
+ tensors: list[tuple[str, Tensor]] = []
+
+ # merge the experts into a single 3d tensor
+ for w_name in ["down_proj", "gate_proj", "up_proj"]:
+ datas: list[Tensor] = []
+
+ for xid in range(n_experts):
+ ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+ datas.append(self._experts[bid][ename])
+ del self._experts[bid][ename]
+
+ data_torch = torch.stack(datas, dim=0)
+
+ merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+ new_name = self.map_tensor_name(merged_name)
+
+ tensors.append((new_name, data_torch))
+ return tensors
+ else:
+ return []
+
+ return [(self.map_tensor_name(name), data_torch)]
+
+ def prepare_tensors(self):
+ super().prepare_tensors()
+
+ if self._experts is not None:
+ # flatten `list[dict[str, Tensor]]` into `list[str]`
+ experts = [k for d in self._experts for k in d.keys()]
+ if len(experts) > 0:
+ raise ValueError(f"Unprocessed experts: {experts}")
+
+
@Model.register("DeepseekV2ForCausalLM")
class DeepseekV2Model(Model):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
LLM_ARCH_OLMOE,
LLM_ARCH_OPENELM,
LLM_ARCH_ARCTIC,
+ LLM_ARCH_DEEPSEEK,
LLM_ARCH_DEEPSEEK2,
LLM_ARCH_CHATGLM,
LLM_ARCH_BITNET,
{ LLM_ARCH_OLMOE, "olmoe" },
{ LLM_ARCH_OPENELM, "openelm" },
{ LLM_ARCH_ARCTIC, "arctic" },
+ { LLM_ARCH_DEEPSEEK, "deepseek" },
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_BITNET, "bitnet" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
+ {
+ LLM_ARCH_DEEPSEEK,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+ { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
+ { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+ { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+ { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+ },
+ },
{
LLM_ARCH_DEEPSEEK2,
{
LLM_CHAT_TEMPLATE_EXAONE_3,
LLM_CHAT_TEMPLATE_RWKV_WORLD,
LLM_CHAT_TEMPLATE_GRANITE,
+ LLM_CHAT_TEMPLATE_GIGACHAT,
LLM_CHAT_TEMPLATE_UNKNOWN,
};
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
{ "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
+ { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
};
static llm_arch llm_arch_from_string(const std::string & name) {
model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_DEEPSEEK:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+ ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
+
+ switch (hparams.n_layer) {
+ case 28: model.type = e_model::MODEL_20B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
case LLM_ARCH_DEEPSEEK2:
{
bool is_lite = (hparams.n_layer == 27);
tokenizer_pre == "phi-2" ||
tokenizer_pre == "jina-es" ||
tokenizer_pre == "jina-de" ||
+ tokenizer_pre == "gigachat" ||
tokenizer_pre == "jina-v1-en" ||
tokenizer_pre == "jina-v2-es" ||
tokenizer_pre == "jina-v2-de" ||
LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
+ if (model.arch == LLM_ARCH_DEEPSEEK) {
+ LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
+ LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+ LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
+ LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
+ }
+
if (model.arch == LLM_ARCH_DEEPSEEK2) {
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
}
} break;
+ case LLM_ARCH_DEEPSEEK:
+ {
+
+ const int64_t n_ff_exp = hparams.n_ff_exp;
+ const int64_t n_expert_shared = hparams.n_expert_shared;
+
+ model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+ if (i < (int) hparams.n_layer_dense_lead) {
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ } else {
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+ if (n_expert == 0) {
+ throw std::runtime_error("n_expert must be > 0");
+ }
+ if (n_expert_used == 0) {
+ throw std::runtime_error("n_expert_used must be > 0");
+ }
+
+ // MoE branch
+ layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+
+ // Shared expert branch
+ layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+ layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0);
+ layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+ }
+ }
+ } break;
case LLM_ARCH_DEEPSEEK2:
{
const bool is_lite = (hparams.n_layer == 27);
return gf;
}
+ struct ggml_cgraph * build_deepseek() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+ // mutable variable, needed during the last layer of the computation to skip unused tokens
+ int32_t n_tokens = this->n_tokens;
+
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = build_inp_pos();
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+ const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // rope freq factors for llama3; may return nullptr for llama2 and other models
+ struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ if (model.layers[il].bq) {
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+ cb(Qcur, "Qcur", il);
+ }
+
+ struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+ if (model.layers[il].bk) {
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+ cb(Kcur, "Kcur", il);
+ }
+
+ struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+ if (model.layers[il].bv) {
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+ cb(Vcur, "Vcur", il);
+ }
+
+ Qcur = ggml_rope_ext(
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_rope_ext(
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Kcur, "Kcur", il);
+
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+ model.layers[il].wo, model.layers[il].bo,
+ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+ }
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ n_tokens = n_outputs;
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ if ((uint32_t) il < hparams.n_layer_dense_lead) {
+ cur = llm_build_ffn(ctx0, lctx, cur,
+ model.layers[il].ffn_up, NULL, NULL,
+ model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+ cb(cur, "ffn_out", il);
+ } else {
+ // MoE branch
+ ggml_tensor * moe_out =
+ llm_build_moe_ffn(ctx0, lctx, cur,
+ model.layers[il].ffn_gate_inp,
+ model.layers[il].ffn_up_exps,
+ model.layers[il].ffn_gate_exps,
+ model.layers[il].ffn_down_exps,
+ n_expert, n_expert_used,
+ LLM_FFN_SILU, false,
+ false, hparams.expert_weights_scale,
+ cb, il);
+ cb(moe_out, "ffn_moe_out", il);
+
+ // FFN shared expert
+ {
+ ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur,
+ model.layers[il].ffn_up_shexp, NULL, NULL,
+ model.layers[il].ffn_gate_shexp, NULL, NULL,
+ model.layers[il].ffn_down_shexp, NULL, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+ cb(ffn_shexp, "ffn_shexp", il);
+
+ cur = ggml_add(ctx0, moe_out, ffn_shexp);
+ cb(cur, "ffn_out", il);
+ }
+ }
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cur = lctx.cvec.apply_to(ctx0, cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
+
struct ggml_cgraph * build_deepseek2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
{
result = llm.build_arctic();
} break;
+ case LLM_ARCH_DEEPSEEK:
+ {
+ result = llm.build_deepseek();
+ } break;
case LLM_ARCH_DEEPSEEK2:
{
result = llm.build_deepseek2();
case LLM_ARCH_COMMAND_R:
case LLM_ARCH_OLMO:
case LLM_ARCH_ARCTIC:
+ case LLM_ARCH_DEEPSEEK:
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_CHATGLM:
case LLM_ARCH_GRANITE:
return LLM_CHAT_TEMPLATE_RWKV_WORLD;
} else if (tmpl_contains("<|start_of_role|>")) {
return LLM_CHAT_TEMPLATE_GRANITE;
+ } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) {
+ return LLM_CHAT_TEMPLATE_GIGACHAT;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
if (add_ass) {
ss << "<|start_of_role|>assistant<|end_of_role|>\n";
}
+ } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) {
+ // GigaChat template
+ bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";
+
+ // Handle system message if present
+ if (has_system) {
+ ss << "<s>" << chat[0]->content << "<|message_sep|>";
+ } else {
+ ss << "<s>";
+ }
+
+ // Process remaining messages
+ for (size_t i = has_system ? 1 : 0; i < chat.size(); i++) {
+ std::string role(chat[i]->role);
+ if (role == "user") {
+ ss << "user<|role_sep|>" << chat[i]->content << "<|message_sep|>"
+ << "available functions<|role_sep|>[]<|message_sep|>";
+ } else if (role == "assistant") {
+ ss << "assistant<|role_sep|>" << chat[i]->content << "<|message_sep|>";
+ }
+ }
+
+ // Add generation prompt if needed
+ if (add_ass) {
+ ss << "assistant<|role_sep|>";
+ }
} else {
// template not supported
return -1;