Add support for ArcticForCausalLM (#7020)
author    fairydreaming <redacted>
          Fri, 24 May 2024 12:31:13 +0000 (14:31 +0200)
committer GitHub <redacted>
          Fri, 24 May 2024 12:31:13 +0000 (14:31 +0200)
* common : increase max number of experts to 128

* common : add tensor LLM_TENSOR_FFN_NORM_EXPS for the normalization applied before the MoE block, which runs in parallel to the attention + FFN path (sketched below)

* gguf-py : add architecture-specific block mappings that override selected general block mappings

* convert-hf : add model conversion support for ArcticForCausalLM

* convert-hf : use added_tokens_decoder from tokenizer_config.json to redefine tokens from SentencePiece model (only for ArcticForCausalLM)

* llama : add inference support for LLM_ARCH_ARCTIC

---------

Co-authored-by: Stanisław Szymczyk <redacted>
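
For orientation, the per-layer dataflow added by this commit can be sketched as follows. This is illustrative Python pseudocode only, not the actual implementation (that is the C++ build_arctic() graph further below); arctic_layer and the layer attributes attention, dense_mlp and moe_ffn are hypothetical placeholders for the corresponding tensors and ops.

    # One Arctic decoder layer, roughly: the MoE path reads its own normalization
    # (ffn_norm_exps) of the layer *input* and its output is added to the
    # attention + dense-FFN result, i.e. the MoE runs in parallel to attention + FFN.
    def arctic_layer(x, layer):
        # attention path: attn_norm -> self-attention (with RoPE) -> residual add
        h = x + layer.attention(layer.attn_norm(x))

        # dense path: ffn_norm (mapped from "residual_layernorm") -> SwiGLU MLP -> residual add
        dense_out = h + layer.dense_mlp(layer.ffn_norm(h))

        # MoE path: ffn_norm_exps (mapped from "post_attention_layernorm") applied to the
        # layer input, then top-k routing over the experts (128 for Arctic)
        moe_out = layer.moe_ffn(layer.ffn_norm_exps(x))

        return dense_out + moe_out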
convert-hf-to-gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
llama.cpp

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 5a00a5e89accbd6b0bf749c8d10c05f578b3275b..998877c26da1966b11ec3446178e64b20c47b575 100755 (executable)
@@ -2466,6 +2466,157 @@ class JinaBertV2Model(BertModel):
         self.gguf_writer.add_add_eos_token(True)
 
 
+@Model.register("ArcticForCausalLM")
+class ArcticModel(Model):
+    model_arch = gguf.MODEL_ARCH.ARCTIC
+
+    def set_vocab(self):
+        # The reason for using a custom implementation here is that the
+        # snowflake-arctic-instruct model redefined tokens 31998 and 31999 from
+        # tokenizer.model and used them as BOS and EOS instead of adding new tokens.
+        from sentencepiece import SentencePieceProcessor
+
+        tokenizer_path = self.dir_model / 'tokenizer.model'
+
+        if not tokenizer_path.is_file():
+            logger.error(f'Error: Missing {tokenizer_path}')
+            sys.exit(1)
+
+        # Read the whole vocabulary from the tokenizer.model file
+        tokenizer = SentencePieceProcessor()
+        tokenizer.LoadFromFile(str(tokenizer_path))
+
+        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+
+        tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
+        scores: list[float] = [-10000.0] * vocab_size
+        toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
+
+        for token_id in range(tokenizer.vocab_size()):
+
+            piece = tokenizer.IdToPiece(token_id)
+            text = piece.encode("utf-8")
+            score = tokenizer.GetScore(token_id)
+
+            toktype = SentencePieceTokenTypes.NORMAL
+            if tokenizer.IsUnknown(token_id):
+                toktype = SentencePieceTokenTypes.UNKNOWN
+            elif tokenizer.IsControl(token_id):
+                toktype = SentencePieceTokenTypes.CONTROL
+            elif tokenizer.IsUnused(token_id):
+                toktype = SentencePieceTokenTypes.UNUSED
+            elif tokenizer.IsByte(token_id):
+                toktype = SentencePieceTokenTypes.BYTE
+
+            tokens[token_id] = text
+            scores[token_id] = score
+            toktypes[token_id] = toktype
+
+        # Use the added_tokens_decoder field from tokenizer_config.json as the source
+        # of information about added/redefined tokens and modify them accordingly.
+        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+        if tokenizer_config_file.is_file():
+            with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+                tokenizer_config_json = json.load(f)
+
+                if "added_tokens_decoder" in tokenizer_config_json:
+                    added_tokens_decoder = tokenizer_config_json["added_tokens_decoder"]
+                    for token_id, token_json in added_tokens_decoder.items():
+                        token_id = int(token_id)
+                        if (token_id >= vocab_size):
+                            logger.debug(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
+                            continue
+
+                        token_content = token_json["content"]
+                        token_type = SentencePieceTokenTypes.USER_DEFINED
+                        token_score = -10000.0
+
+                        # Map unk_token to UNKNOWN, other special tokens to CONTROL
+                        # Set the score to 0.0 as in the original tokenizer.model
+                        if ("special" in token_json) and token_json["special"]:
+                            if token_content == tokenizer_config_json["unk_token"]:
+                                token_type = SentencePieceTokenTypes.UNKNOWN
+                            else:
+                                token_type = SentencePieceTokenTypes.CONTROL
+                            token_score = 0.0
+
+                        logger.info(f"Setting added token {token_id} to '{token_content}' (type: {token_type}, score: {token_score:.2f})")
+                        tokens[token_id] = token_content.encode("utf-8")
+                        toktypes[token_id] = token_type
+                        scores[token_id] = token_score
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+        self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        n_head = self.hparams["num_attention_heads"]
+        n_kv_head = self.hparams.get("num_key_value_heads")
+
+        if name.endswith("q_proj.weight"):
+            data_torch = LlamaModel.permute(data_torch, n_head, n_head)
+        if name.endswith("k_proj.weight"):
+            data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
+
+        # process the experts separately
+        if name.find("block_sparse_moe.experts") != -1:
+            n_experts = self.hparams["num_local_experts"]
+
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for wid in ["w1", "w2", "w3"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"layers.{bid}.feed_forward.experts.{wid}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def write_tensors(self):
+        super().write_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 ###### CONVERSION LOGIC ######
 
 
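For context, the added_tokens_decoder field consumed by set_vocab() above lives in the Hugging Face tokenizer_config.json shipped with the model. Below is a minimal, hypothetical sketch of its shape and of what the override loop does with such an entry; only the ids 31998/31999 come from the comment above, while the token strings and the unk_token value are invented for illustration.

    import json

    # hypothetical tokenizer_config.json excerpt (invented token contents)
    tokenizer_config_json = json.loads("""
    {
      "unk_token": "<unk>",
      "added_tokens_decoder": {
        "31998": { "content": "<custom_bos>", "special": true },
        "31999": { "content": "<custom_eos>", "special": true }
      }
    }
    """)

    # both entries are "special" and differ from unk_token, so the loop above turns
    # them into CONTROL tokens with score 0.0, replacing whatever pieces
    # tokenizer.model originally stored at ids 31998 and 31999
    for token_id, token_json in tokenizer_config_json["added_tokens_decoder"].items():
        is_unk = token_json["content"] == tokenizer_config_json["unk_token"]
        print(int(token_id), token_json["content"], "UNKNOWN" if is_unk else "CONTROL")
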
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 67e23dcc148402d7671cf959991e62919a119c03..c9ae259e1d6278cafb23c599cbe857d925ab9e80 100644 (file)
@@ -139,6 +139,7 @@ class MODEL_ARCH(IntEnum):
     COMMAND_R  = auto()
     DBRX       = auto()
     OLMO       = auto()
+    ARCTIC     = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -167,6 +168,7 @@ class MODEL_TENSOR(IntEnum):
     FFN_DOWN           = auto()
     FFN_UP             = auto()
     FFN_ACT            = auto()
+    FFN_NORM_EXP       = auto()
     FFN_GATE_EXP       = auto()
     FFN_DOWN_EXP       = auto()
     FFN_UP_EXP         = auto()
@@ -218,6 +220,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.COMMAND_R:      "command-r",
     MODEL_ARCH.DBRX:           "dbrx",
     MODEL_ARCH.OLMO:           "olmo",
+    MODEL_ARCH.ARCTIC:         "arctic",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -251,6 +254,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.FFN_DOWN_SHEXP:     "blk.{bid}.ffn_down_shexp",
     MODEL_TENSOR.FFN_UP_SHEXP:       "blk.{bid}.ffn_up_shexp",
     MODEL_TENSOR.FFN_ACT:            "blk.{bid}.ffn",
+    MODEL_TENSOR.FFN_NORM_EXP:       "blk.{bid}.ffn_norm_exps",
     MODEL_TENSOR.FFN_GATE_EXP:       "blk.{bid}.ffn_gate_exps",
     MODEL_TENSOR.FFN_DOWN_EXP:       "blk.{bid}.ffn_down_exps",
     MODEL_TENSOR.FFN_UP_EXP:         "blk.{bid}.ffn_up_exps",
@@ -732,6 +736,27 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.ARCTIC: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_NORM_EXP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     # TODO
 }
 
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 8e1cac9152f55ea2491e9dc31d3c7b78a53c042d..8b1b21d78bb09835d5156f3cb69e3d62b566bf5c 100644 (file)
@@ -244,6 +244,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.mlp.fc11",                          # nomic-bert
             "model.layers.{bid}.mlp.c_fc",                            # starcoder2
             "encoder.layer.{bid}.mlp.gated_layers_v",                 # jina-bert-v2
+            "model.layers.{bid}.residual_mlp.w3",                     # arctic
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -272,6 +273,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.mlp.fc12",              # nomic-bert
             "encoder.layer.{bid}.mlp.gated_layers_w",     # jina-bert-v2
             "transformer.h.{bid}.mlp.linear_1",           # refact
+            "model.layers.{bid}.residual_mlp.w1",         # arctic
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -306,6 +308,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.mlp.fc2",                           # nomic-bert
             "model.layers.{bid}.mlp.c_proj",                          # starcoder2
             "encoder.layer.{bid}.mlp.wo",                             # jina-bert-v2
+            "model.layers.{bid}.residual_mlp.w2",                     # arctic
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -382,6 +385,18 @@ class TensorNameMap:
         ),
     }
 
+    # architecture-specific block mappings
+    arch_block_mappings_cfg: dict[MODEL_ARCH, dict[MODEL_TENSOR, tuple[str, ...]]] = {
+        MODEL_ARCH.ARCTIC: {
+            MODEL_TENSOR.FFN_NORM: (
+                "model.layers.{bid}.residual_layernorm",
+            ),
+            MODEL_TENSOR.FFN_NORM_EXP: (
+                "model.layers.{bid}.post_attention_layernorm",
+            ),
+        },
+    }
+
     mapping: dict[str, tuple[MODEL_TENSOR, str]]
 
     def __init__(self, arch: MODEL_ARCH, n_blocks: int):
@@ -393,12 +408,14 @@ class TensorNameMap:
             self.mapping[tensor_name] = (tensor, tensor_name)
             for key in keys:
                 self.mapping[key] = (tensor, tensor_name)
+        if arch in self.arch_block_mappings_cfg:
+            self.block_mappings_cfg.update(self.arch_block_mappings_cfg[arch])
         for bid in range(n_blocks):
             for tensor, keys in self.block_mappings_cfg.items():
                 if tensor not in MODEL_TENSORS[arch]:
                     continue
                 # TODO: make this configurable
-                n_experts = 60
+                n_experts = 128
                 for xid in range(n_experts):
                     tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
                     self.mapping[tensor_name] = (tensor, tensor_name)
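
The effect of the architecture-specific overrides can be checked with gguf-py's public name-map helper. A small sketch, assuming a gguf-py checkout that contains this change; the block count of 35 is taken from the Arctic hparams handled in llama.cpp below.

    import gguf

    # build the tensor name map for Arctic (35 layers)
    name_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.ARCTIC, n_blocks=35)

    # "residual_layernorm" feeds the dense FFN and resolves to the regular ffn_norm,
    # while "post_attention_layernorm" now resolves to the new ffn_norm_exps tensor
    print(name_map.get_name("model.layers.0.residual_layernorm.weight",
                            try_suffixes=(".weight",)))   # blk.0.ffn_norm.weight
    print(name_map.get_name("model.layers.0.post_attention_layernorm.weight",
                            try_suffixes=(".weight",)))   # blk.0.ffn_norm_exps.weight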
diff --git a/llama.cpp b/llama.cpp
index 15c66077525a71eb2bdfbce2e6c01cf6509b5741..3c9fe15bb459688bc222a5d91a047000e1d4fbbd 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
 #endif
 
 #define LLAMA_MAX_NODES   8192
-#define LLAMA_MAX_EXPERTS 60
+#define LLAMA_MAX_EXPERTS 128
 
 //
 // logging
@@ -221,6 +221,7 @@ enum llm_arch {
     LLM_ARCH_COMMAND_R,
     LLM_ARCH_DBRX,
     LLM_ARCH_OLMO,
+    LLM_ARCH_ARCTIC,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -257,6 +258,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_COMMAND_R,       "command-r"    },
     { LLM_ARCH_DBRX,            "dbrx"         },
     { LLM_ARCH_OLMO,            "olmo"         },
+    { LLM_ARCH_ARCTIC,          "arctic"       },
     { LLM_ARCH_UNKNOWN,         "(unknown)"    },
 };
 
@@ -455,6 +457,7 @@ enum llm_tensor {
     LLM_TENSOR_FFN_DOWN_EXP,  // split experts for backward compatibility
     LLM_TENSOR_FFN_GATE_EXP,
     LLM_TENSOR_FFN_UP_EXP,
+    LLM_TENSOR_FFN_NORM_EXPS,
     LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
     LLM_TENSOR_FFN_GATE_EXPS,
     LLM_TENSOR_FFN_UP_EXPS,
@@ -1032,6 +1035,28 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_ARCTIC,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_NORM_EXPS,   "blk.%d.ffn_norm_exps" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1732,6 +1757,7 @@ enum e_model {
     MODEL_8x7B,
     MODEL_8x22B,
     MODEL_16x12B,
+    MODEL_10B_128x3_66B,
 };
 
 static const size_t kiB = 1024;
@@ -1907,6 +1933,7 @@ struct llama_layer {
     struct ggml_tensor * ffn_norm_b;
     struct ggml_tensor * layer_out_norm;
     struct ggml_tensor * layer_out_norm_b;
+    struct ggml_tensor * ffn_norm_exps;
 
     // ff
     struct ggml_tensor * ffn_gate; // w1
@@ -3781,47 +3808,48 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
 
 static const char * llama_model_type_name(e_model type) {
     switch (type) {
-        case MODEL_14M:    return "14M";
-        case MODEL_17M:    return "17M";
-        case MODEL_22M:    return "22M";
-        case MODEL_33M:    return "33M";
-        case MODEL_70M:    return "70M";
-        case MODEL_109M:   return "109M";
-        case MODEL_137M:   return "137M";
-        case MODEL_160M:   return "160M";
-        case MODEL_335M:   return "335M";
-        case MODEL_410M:   return "410M";
-        case MODEL_0_5B:   return "0.5B";
-        case MODEL_1B:     return "1B";
-        case MODEL_1_4B:   return "1.4B";
-        case MODEL_2B:     return "2B";
-        case MODEL_2_8B:   return "2.8B";
-        case MODEL_3B:     return "3B";
-        case MODEL_4B:     return "4B";
-        case MODEL_6_9B:   return "6.9B";
-        case MODEL_7B:     return "7B";
-        case MODEL_8B:     return "8B";
-        case MODEL_12B:    return "12B";
-        case MODEL_13B:    return "13B";
-        case MODEL_14B:    return "14B";
-        case MODEL_15B:    return "15B";
-        case MODEL_20B:    return "20B";
-        case MODEL_30B:    return "30B";
-        case MODEL_34B:    return "34B";
-        case MODEL_35B:    return "35B";
-        case MODEL_40B:    return "40B";
-        case MODEL_65B:    return "65B";
-        case MODEL_70B:    return "70B";
-        case MODEL_314B:   return "314B";
-        case MODEL_SMALL:  return "0.1B";
-        case MODEL_MEDIUM: return "0.4B";
-        case MODEL_LARGE:  return "0.8B";
-        case MODEL_XL:     return "1.5B";
-        case MODEL_A2_7B:  return "A2.7B";
-        case MODEL_8x7B:   return "8x7B";
-        case MODEL_8x22B:  return "8x22B";
-        case MODEL_16x12B: return "16x12B";
-        default:           return "?B";
+        case MODEL_14M:           return "14M";
+        case MODEL_17M:           return "17M";
+        case MODEL_22M:           return "22M";
+        case MODEL_33M:           return "33M";
+        case MODEL_70M:           return "70M";
+        case MODEL_109M:          return "109M";
+        case MODEL_137M:          return "137M";
+        case MODEL_160M:          return "160M";
+        case MODEL_335M:          return "335M";
+        case MODEL_410M:          return "410M";
+        case MODEL_0_5B:          return "0.5B";
+        case MODEL_1B:            return "1B";
+        case MODEL_1_4B:          return "1.4B";
+        case MODEL_2B:            return "2B";
+        case MODEL_2_8B:          return "2.8B";
+        case MODEL_3B:            return "3B";
+        case MODEL_4B:            return "4B";
+        case MODEL_6_9B:          return "6.9B";
+        case MODEL_7B:            return "7B";
+        case MODEL_8B:            return "8B";
+        case MODEL_12B:           return "12B";
+        case MODEL_13B:           return "13B";
+        case MODEL_14B:           return "14B";
+        case MODEL_15B:           return "15B";
+        case MODEL_20B:           return "20B";
+        case MODEL_30B:           return "30B";
+        case MODEL_34B:           return "34B";
+        case MODEL_35B:           return "35B";
+        case MODEL_40B:           return "40B";
+        case MODEL_65B:           return "65B";
+        case MODEL_70B:           return "70B";
+        case MODEL_314B:          return "314B";
+        case MODEL_SMALL:         return "0.1B";
+        case MODEL_MEDIUM:        return "0.4B";
+        case MODEL_LARGE:         return "0.8B";
+        case MODEL_XL:            return "1.5B";
+        case MODEL_A2_7B:         return "A2.7B";
+        case MODEL_8x7B:          return "8x7B";
+        case MODEL_8x22B:         return "8x22B";
+        case MODEL_16x12B:        return "16x12B";
+        case MODEL_10B_128x3_66B: return "10B+128x3.66B";
+        default:                  return "?B";
     }
 }
 
@@ -4343,6 +4371,19 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_ARCTIC:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                if (hparams.n_expert == 128) {
+                    switch (hparams.n_layer) {
+                        case 35: model.type = e_model::MODEL_10B_128x3_66B; break;
+                        default: model.type = e_model::MODEL_UNKNOWN;
+                    }
+                } else {
+                    model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -6129,6 +6170,46 @@ static bool llm_load_tensors(
                         layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
                     }
                 } break;
+            case LLM_ARCH_ARCTIC:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        // if output is NULL, init from the input tok embed
+                        if (model.output == NULL) {
+                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                        }
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_embd});
+
+                        layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
+                        layer.ffn_norm_exps = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd});
+                        layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, false);
+                        layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert});
+                        layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert});
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -10790,6 +10871,140 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_arctic() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, cur,
+                    model.layers[il].ffn_up,   NULL,
+                    model.layers[il].ffn_gate, NULL,
+                    model.layers[il].ffn_down, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            struct ggml_tensor * ffn_out = ggml_add(ctx0, cur, ffn_inp);
+            cb(ffn_out, "ffn_out", il);
+
+            // MoE
+            cur = llm_build_norm(ctx0, inpSA, hparams,
+                    model.layers[il].ffn_norm_exps, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm_exps", il);
+
+            cur = llm_build_moe_ffn(ctx0, cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, true,
+                    cb, il);
+            cb(cur, "ffn_moe_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_out);
+            cb(cur, "ffn_out", il);
+
+            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+            if (layer_dir != nullptr) {
+                cur = ggml_add(ctx0, cur, layer_dir);
+            }
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -11004,6 +11219,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_gptneox();
             } break;
+        case LLM_ARCH_ARCTIC:
+            {
+                result = llm.build_arctic();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -16015,6 +16234,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_XVERSE:
         case LLM_ARCH_COMMAND_R:
         case LLM_ARCH_OLMO:
+        case LLM_ARCH_ARCTIC:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2