git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add Deepseek MoE v1 & GigaChat models (#10827)
author Valentin Mamedov <redacted>
Sun, 15 Dec 2024 17:02:46 +0000 (00:02 +0700)
committer GitHub <redacted>
Sun, 15 Dec 2024 17:02:46 +0000 (19:02 +0200)
* Add deepseek v1 arch & gigachat template

* improve template code

* add readme

* delete comments

* remove comment

* fix format

* lint llama.cpp

* fix order of deepseek and deepseek2, move gigachat template to the end of the function

* fix order of deepseek and deepseek2 in constants; mark shared experts as needed by the deepseek arch

* remove comments

* move deepseek above deepseek2

* change placement of gigachat chat template

README.md
convert_hf_to_gguf.py
convert_hf_to_gguf_update.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
src/llama.cpp
tests/test-chat-template.cpp

index 49095acc8bbfc2a0cbb141b9c63d583f0c9effa0..a3b4917540a949a0f5664768d82edc5bdf926dc6 100644 (file)
--- a/README.md
+++ b/README.md
@@ -98,6 +98,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat)
 - [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a)
 - [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM)
+- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
 
 #### Multimodal
 
index 206ab502ecadd627f180db6373df9a82c905cf11..9dc1673bc2c06cddaf1a24b38a14985f8162c23d 100755 (executable)
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -664,6 +664,9 @@ class Model:
         if chkhsh == "8b5a93ed704057481f240da0be7e7dca721d7f8f4755263b6807227a2cbeae65":
             # ref: https://huggingface.co/sentence-transformers/stsb-roberta-base
             res = "roberta-bpe"
+        if chkhsh == "ad851be1dba641f2e3711822f816db2c265f788b37c63b4e1aeacb9ee92de8eb":
+            # ref: https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct
+            res = "gigachat"
 
         if res is None:
             logger.warning("\n")
@@ -3427,6 +3430,97 @@ class ArcticModel(Model):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@Model.register("DeepseekForCausalLM")
+class DeepseekModel(Model):
+    model_arch = gguf.MODEL_ARCH.DEEPSEEK
+
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        if "head_dim" in hparams:
+            rope_dim = hparams["head_dim"]
+        else:
+            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+
+        self.gguf_writer.add_rope_dimension_count(rope_dim)
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+        self.gguf_writer.add_expert_weights_scale(1.0)
+        self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
+        self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    @staticmethod
+    def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
+        if n_head_kv is not None and n_head != n_head_kv:
+            n_head = n_head_kv
+        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+                .swapaxes(1, 2)
+                .reshape(weights.shape))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        n_head = self.hparams["num_attention_heads"]
+        n_kv_head = self.hparams.get("num_key_value_heads")
+
+        if name.endswith(("q_proj.weight", "q_proj.bias")):
+            data_torch = DeepseekModel.permute(data_torch, n_head, n_head)
+        if name.endswith(("k_proj.weight", "k_proj.bias")):
+            data_torch = DeepseekModel.permute(data_torch, n_head, n_kv_head)
+
+        # process the experts separately
+        if name.find("mlp.experts") != -1:
+            n_experts = self.hparams["n_routed_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @Model.register("DeepseekV2ForCausalLM")
 class DeepseekV2Model(Model):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
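
For reference, the expert-merging step in DeepseekModel.modify_tensors() collects the per-expert down/gate/up projections of a layer and stacks them into one 3D tensor per projection, which is then written under a merged name. A minimal, self-contained sketch of that stacking with hypothetical sizes (not GigaChat's real hyperparameters):

import torch

n_experts, n_embd, n_ff_exp = 4, 8, 16  # hypothetical sizes, for illustration only
bid = 0                                 # block (layer) index

# per-expert weights as they appear in the HF checkpoint (nn.Linear stores (out_features, in_features))
experts = {
    f"model.layers.{bid}.mlp.experts.{xid}.{w}.weight": torch.randn(
        (n_embd, n_ff_exp) if w == "down_proj" else (n_ff_exp, n_embd))
    for xid in range(n_experts)
    for w in ("down_proj", "gate_proj", "up_proj")
}

# merge each projection across experts, as modify_tensors() does once all experts of a layer arrived
for w in ("down_proj", "gate_proj", "up_proj"):
    stacked = torch.stack(
        [experts[f"model.layers.{bid}.mlp.experts.{xid}.{w}.weight"] for xid in range(n_experts)],
        dim=0)
    print(w, tuple(stacked.shape))  # (n_experts, out_features, in_features)

The merged names are then resolved by map_tensor_name() to the blk.%d.ffn_{gate,down,up}_exps tensors consumed on the C++ side below.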
index aa10e5db796a0458bcfb14a6d1cbf195626e660a..88058442f6dc4ff1416e734df231a2b7befce42f 100755 (executable)
--- a/convert_hf_to_gguf_update.py
+++ b/convert_hf_to_gguf_update.py
@@ -104,6 +104,7 @@ models = [
     {"name": "chameleon",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", },
     {"name": "minerva-7b",     "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", },
     {"name": "roberta-bpe",    "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"},
+    {"name": "gigachat",       "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"},
 ]
 
 
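The chkhsh value matched in convert_hf_to_gguf.py above is produced by hashing the token IDs that the model's tokenizer emits for a fixed check string; convert_hf_to_gguf_update.py downloads each tokenizer in this list and prints the hash so the lookup table can be regenerated. A minimal sketch of the derivation, assuming the transformers AutoTokenizer and a placeholder check string (the real scripts use a much longer multilingual test string):

from hashlib import sha256
from transformers import AutoTokenizer

chktxt = "placeholder check text"  # the real chktxt is a long multilingual/emoji string

tokenizer = AutoTokenizer.from_pretrained("ai-sage/GigaChat-20B-A3B-instruct")
chktok = tokenizer.encode(chktxt)                  # token IDs for the check text
chkhsh = sha256(str(chktok).encode()).hexdigest()  # hex digest compared in get_vocab_base_pre()

print(f"chkhsh = {chkhsh}")
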
index 0cf93942916546256d7ce2610baed941057b27ac..c2c7cad14e500ae172a867e2175cdbcf4eeba883 100644 (file)
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -249,6 +249,7 @@ class MODEL_ARCH(IntEnum):
     OLMOE        = auto()
     OPENELM      = auto()
     ARCTIC       = auto()
+    DEEPSEEK     = auto()
     DEEPSEEK2    = auto()
     CHATGLM      = auto()
     BITNET       = auto()
@@ -412,6 +413,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.OLMOE:          "olmoe",
     MODEL_ARCH.OPENELM:        "openelm",
     MODEL_ARCH.ARCTIC:         "arctic",
+    MODEL_ARCH.DEEPSEEK:       "deepseek",
     MODEL_ARCH.DEEPSEEK2:      "deepseek2",
     MODEL_ARCH.CHATGLM:        "chatglm",
     MODEL_ARCH.BITNET:         "bitnet",
@@ -1158,6 +1160,29 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.DEEPSEEK: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+    ],
     MODEL_ARCH.DEEPSEEK2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
@@ -1380,6 +1405,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.DEEPSEEK: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
     MODEL_ARCH.DEEPSEEK2: [
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
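
Once these constants land, the new architecture can be introspected from gguf-py like any other. A small sketch, assuming the gguf package from this tree is importable and using TENSOR_NAMES (the existing name-template table in gguf-py/gguf/constants.py):

import gguf

arch = gguf.MODEL_ARCH.DEEPSEEK
print(gguf.MODEL_ARCH_NAMES[arch])  # "deepseek"

# list the GGUF tensor name templates registered for the arch, including the shared-expert entries
for t in gguf.MODEL_TENSORS[arch]:
    print(gguf.TENSOR_NAMES[t])     # e.g. "blk.{bid}.ffn_gate_shexp"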
index f0a7b6478508efdce02de2dc65492e8c6e603526..573d0282ea599abb63523a9fbd24b9fe9dba5669 100644 (file)
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -306,7 +306,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.FFN_UP_SHEXP: (
             "model.layers.{bid}.mlp.shared_expert.up_proj",  # qwen2moe
-            "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2
+            "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
         ),
 
         # AWQ-activation gate
@@ -338,7 +338,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.FFN_GATE_SHEXP: (
             "model.layers.{bid}.mlp.shared_expert.gate_proj",  # qwen2moe
-            "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek2
+            "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
         ),
 
         # Feed-forward down
@@ -379,7 +379,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.FFN_DOWN_SHEXP: (
             "model.layers.{bid}.mlp.shared_expert.down_proj",  # qwen2moe
-            "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek2
+            "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
         ),
 
         MODEL_TENSOR.ATTN_Q_NORM: (
index abc1252e72819955740bf2ffb09bb5ddcbd9b4ae..8b799e0ebeda72e86c2cf6b296f23375756ce338 100644 (file)
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -184,6 +184,7 @@ enum llm_arch {
     LLM_ARCH_OLMOE,
     LLM_ARCH_OPENELM,
     LLM_ARCH_ARCTIC,
+    LLM_ARCH_DEEPSEEK,
     LLM_ARCH_DEEPSEEK2,
     LLM_ARCH_CHATGLM,
     LLM_ARCH_BITNET,
@@ -239,6 +240,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_OLMOE,           "olmoe"        },
     { LLM_ARCH_OPENELM,         "openelm"      },
     { LLM_ARCH_ARCTIC,          "arctic"       },
+    { LLM_ARCH_DEEPSEEK,        "deepseek"     },
     { LLM_ARCH_DEEPSEEK2,       "deepseek2"    },
     { LLM_ARCH_CHATGLM,         "chatglm"      },
     { LLM_ARCH_BITNET,          "bitnet"       },
@@ -1309,6 +1311,33 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_DEEPSEEK,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ROPE_FREQS,         "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,      "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,     "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
+        },
+    },
     {
         LLM_ARCH_DEEPSEEK2,
         {
@@ -1600,6 +1629,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_EXAONE_3,
     LLM_CHAT_TEMPLATE_RWKV_WORLD,
     LLM_CHAT_TEMPLATE_GRANITE,
+    LLM_CHAT_TEMPLATE_GIGACHAT,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
@@ -1631,6 +1661,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "exaone3",           LLM_CHAT_TEMPLATE_EXAONE_3          },
     { "rwkv-world",        LLM_CHAT_TEMPLATE_RWKV_WORLD        },
     { "granite",           LLM_CHAT_TEMPLATE_GRANITE           },
+    { "gigachat",          LLM_CHAT_TEMPLATE_GIGACHAT          },
 };
 
 static llm_arch llm_arch_from_string(const std::string & name) {
@@ -6094,6 +6125,19 @@ static void llm_load_hparams(
                     model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_DEEPSEEK:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
+
+                switch (hparams.n_layer) {
+                    case 28: model.type = e_model::MODEL_20B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_DEEPSEEK2:
             {
                 bool is_lite = (hparams.n_layer == 27);
@@ -6440,6 +6484,7 @@ static void llm_load_vocab(
                     tokenizer_pre == "phi-2"   ||
                     tokenizer_pre == "jina-es" ||
                     tokenizer_pre == "jina-de" ||
+                    tokenizer_pre == "gigachat"   ||
                     tokenizer_pre == "jina-v1-en" ||
                     tokenizer_pre == "jina-v2-es" ||
                     tokenizer_pre == "jina-v2-de" ||
@@ -7091,6 +7136,13 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
 
     LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
 
+    if (model.arch == LLM_ARCH_DEEPSEEK) {
+        LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
+        LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n",     __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n",     __func__, hparams.n_expert_shared);
+        LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
+    }
+
     if (model.arch == LLM_ARCH_DEEPSEEK2) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q             = %d\n",     __func__, hparams.n_lora_q);
@@ -8865,6 +8917,55 @@ static bool llm_load_tensors(
                         layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
                     }
                 } break;
+            case LLM_ARCH_DEEPSEEK:
+                {
+
+                    const int64_t n_ff_exp        = hparams.n_ff_exp;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
+
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (i < (int) hparams.n_layer_dense_lead) {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        } else {
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                            if (n_expert == 0) {
+                                throw std::runtime_error("n_expert must be > 0");
+                            }
+                            if (n_expert_used == 0) {
+                                throw std::runtime_error("n_expert_used must be > 0");
+                            }
+
+                            // MoE branch
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                            // Shared expert branch
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                        }
+                    }
+                } break;
             case LLM_ARCH_DEEPSEEK2:
                 {
                     const bool is_lite = (hparams.n_layer == 27);
@@ -15219,6 +15320,161 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_deepseek() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            if ((uint32_t) il < hparams.n_layer_dense_lead) {
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                ggml_tensor * moe_out =
+                        llm_build_moe_ffn(ctx0, lctx, cur,
+                            model.layers[il].ffn_gate_inp,
+                            model.layers[il].ffn_up_exps,
+                            model.layers[il].ffn_gate_exps,
+                            model.layers[il].ffn_down_exps,
+                            n_expert, n_expert_used,
+                            LLM_FFN_SILU, false,
+                            false, hparams.expert_weights_scale,
+                            cb, il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                // FFN shared expert
+                {
+                    ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                }
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_deepseek2() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
@@ -16906,6 +17162,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_arctic();
             } break;
+        case LLM_ARCH_DEEPSEEK:
+            {
+                result = llm.build_deepseek();
+            } break;
         case LLM_ARCH_DEEPSEEK2:
             {
                 result = llm.build_deepseek2();
@@ -20137,6 +20397,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_COMMAND_R:
         case LLM_ARCH_OLMO:
         case LLM_ARCH_ARCTIC:
+        case LLM_ARCH_DEEPSEEK:
         case LLM_ARCH_DEEPSEEK2:
         case LLM_ARCH_CHATGLM:
         case LLM_ARCH_GRANITE:
@@ -22002,6 +22263,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_RWKV_WORLD;
     } else if (tmpl_contains("<|start_of_role|>")) {
         return LLM_CHAT_TEMPLATE_GRANITE;
+    } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) {
+        return LLM_CHAT_TEMPLATE_GIGACHAT;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -22325,6 +22588,32 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|start_of_role|>assistant<|end_of_role|>\n";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) {
+        // GigaChat template
+        bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";
+
+        // Handle system message if present
+        if (has_system) {
+            ss << "<s>" << chat[0]->content << "<|message_sep|>";
+        } else {
+            ss << "<s>";
+        }
+
+        // Process remaining messages
+        for (size_t i = has_system ? 1 : 0; i < chat.size(); i++) {
+            std::string role(chat[i]->role);
+            if (role == "user") {
+                ss << "user<|role_sep|>" << chat[i]->content << "<|message_sep|>"
+                << "available functions<|role_sep|>[]<|message_sep|>";
+            } else if (role == "assistant") {
+                ss << "assistant<|role_sep|>" << chat[i]->content << "<|message_sep|>";
+            }
+        }
+
+        // Add generation prompt if needed
+        if (add_ass) {
+            ss << "assistant<|role_sep|>";
+        }
     } else {
         // template not supported
         return -1;
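
To make the new GigaChat branch easier to follow, here is a minimal Python re-rendering of the same formatting rules (an illustrative sketch with a hypothetical helper name, not part of the patch); it reproduces the formatting seen in the expected output added to tests/test-chat-template.cpp below:

def format_gigachat(chat, add_generation_prompt=True):
    # chat is a list of {"role": ..., "content": ...} dicts, mirroring llama_chat_message
    out = "<s>"
    msgs = chat
    if chat and chat[0]["role"] == "system":
        out += chat[0]["content"] + "<|message_sep|>"
        msgs = chat[1:]

    for msg in msgs:
        if msg["role"] == "user":
            out += ("user<|role_sep|>" + msg["content"] + "<|message_sep|>"
                    + "available functions<|role_sep|>[]<|message_sep|>")
        elif msg["role"] == "assistant":
            out += "assistant<|role_sep|>" + msg["content"] + "<|message_sep|>"

    if add_generation_prompt:
        out += "assistant<|role_sep|>"  # generation prompt, same as add_ass in the C++ code
    return out

print(format_gigachat([
    {"role": "system",    "content": "You are a helpful assistant"},
    {"role": "user",      "content": "Hello"},
    {"role": "assistant", "content": "Hi there"},
]))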
index aa140b5696f743a9bdd7af1ca5eb3d18cf0e483e..30a910ad5c55df1a223380dea0aff15acdf199d7 100644 (file)
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -75,6 +75,8 @@ int main(void) {
         "{%- if messages[0][\"role\"] == \"system\" %}\n    {%- set system_message = messages[0][\"content\"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n    {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n        {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n            {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n        {%- endif %}\n        {%- set ns.index = ns.index + 1 %}\n    {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if message[\"role\"] == \"user\" %}\n        {%- if tools is not none and (message == user_messages[-1]) %}\n            {{- \"[AVAILABLE_TOOLS][\" }}\n            {%- for tool in tools %}\n                {%- set tool = tool.function %}\n                {{- '{\"type\": \"function\", \"function\": {' }}\n                {%- for key, val in tool.items() if key != \"return\" %}\n                    {%- if val is string %}\n                        {{- '\"' + key + '\": \"' + val + '\"' }}\n                    {%- else %}\n                        {{- '\"' + key + '\": ' + val|tojson }}\n                    {%- endif %}\n                    {%- if not loop.last %}\n                        {{- \", \" }}\n                    {%- endif %}\n                {%- endfor %}\n                {{- \"}}\" }}\n                {%- if not loop.last %}\n                    {{- \", \" }}\n                {%- else %}\n                    {{- \"]\" }}\n                {%- endif %}\n            {%- endfor %}\n            {{- \"[/AVAILABLE_TOOLS]\" }}\n            {%- endif %}\n        {%- if loop.last and system_message is defined %}\n            {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n        {%- else %}\n            {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n        {%- endif %}\n    {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n        {{- \"[TOOL_CALLS][\" }}\n        {%- for tool_call in message.tool_calls %}\n            {%- set out = tool_call.function|tojson %}\n            {{- out[:-1] }}\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n                {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n            {%- endif %}\n            {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n            {%- if not loop.last %}\n                {{- \", \" }}\n            {%- else %}\n                {{- \"]\" + eos_token }}\n            {%- endif %}\n        {%- endfor %}\n    {%- elif message[\"role\"] == \"assistant\" %}\n        {{- message[\"content\"] + eos_token}}\n    {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n        {%- if message.content is defined and message.content.content is defined %}\n            {%- set content = message.content.content %}\n        {%- else %}\n            {%- set content = 
message.content %}\n        {%- endif %}\n        {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n        {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n            {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n        {%- endif %}\n        {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n    {%- else %}\n        {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n    {%- endif %}\n{%- endfor %}\n",
         // mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)
         "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'system' %}{{ '[SYSTEM_PROMPT] ' + message['content'] + '[/SYSTEM_PROMPT]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% else %}{{ raise_exception('Only user, system and assistant roles are supported!') }}{% endif %}{% endfor %}",
+        // ai-sage/GigaChat-20B-A3B-instruct
+        "{% if messages[0]['role'] == 'system' -%}\n    {%- set loop_messages = messages[1:] -%}\n    {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n    {%- set loop_messages = messages -%}\n    {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n    {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n    {% endif %}\n    \n    {%- if loop.index0 == 0 -%}\n        {{ system_message -}}\n    {%- endif -%}\n    {%- if message['role'] == 'user' -%}\n        {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n        {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3]  + additional_special_tokens[1] -}}\n    {%- endif -%}\n    {%- if message['role'] == 'assistant' -%}\n        {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n    {%- endif -%}\n    {%- if loop.last and add_generation_prompt -%}\n        {{ 'assistant' + additional_special_tokens[0] -}}\n    {%- endif -%}\n{%- endfor %}",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
@@ -129,6 +131,8 @@ int main(void) {
         "[INST]You are a helpful assistant\n\nHello[/INST]Hi there</s>[INST]Who are you[/INST]   I am an assistant   </s>[INST]Another question[/INST]",
         // mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)
         "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT][INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST]    I am an assistant   </s>[INST] Another question[/INST]",
+        // ai-sage/GigaChat-20B-A3B-instruct
+        "<s>You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>   I am an assistant   <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
@@ -190,6 +194,7 @@ int main(void) {
     assert(fmt_sys("mistral") == "[INST] You are a helpful assistant\n"); // for old pre-v1 templates
     assert(fmt_sys("gemma")  == ""); // for gemma, system message is merged with user message
     assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>");
+    assert(fmt_sys("gigachat") == "<s>You are a helpful assistant<|message_sep|>");
 
 
     // test llama_chat_format_single for user message
@@ -214,6 +219,7 @@ int main(void) {
     assert(fmt_single("mistral") == "[INST] How are you [/INST]"); // for old pre-v1 templates
     assert(fmt_single("gemma")  == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
     assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
+    assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>");
 
     printf("Test chat templates: OK\n");