llama : add grok-1 support (#6204)
author Julius Arkenberg <redacted>
Sat, 23 Mar 2024 16:41:53 +0000 (17:41 +0100)
committer GitHub <redacted>
Sat, 23 Mar 2024 16:41:53 +0000 (18:41 +0200)
* Add support for Grok model architecture

* Revert convert-hf-to-gguf to default options

* Fixed f_norm_rms_eps bug

* Fix whitespaces

* llama : fix grok rope type

* llama : minor

---------

Co-authored-by: Georgi Gerganov <redacted>
convert-hf-to-gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
llama.cpp

index 1e49d56c19514ec9ca171cd8ce5c6a47d2ce1284..723ea18e34c658dcaedda8e13f53084e0eb5b084 100755 (executable)
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -93,31 +93,42 @@ class Model(ABC):
 
         if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
             self.gguf_writer.add_context_length(n_ctx)
+            print(f"gguf: context length = {n_ctx}")
 
         n_embd = self.find_hparam(["hidden_size", "n_embd"])
         self.gguf_writer.add_embedding_length(n_embd)
+        print(f"gguf: embedding length = {n_embd}")
 
         if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
             self.gguf_writer.add_feed_forward_length(n_ff)
+            print(f"gguf: feed forward length = {n_ff}")
 
         n_head = self.find_hparam(["num_attention_heads", "n_head"])
         self.gguf_writer.add_head_count(n_head)
+        print(f"gguf: head count = {n_head}")
 
         if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
             self.gguf_writer.add_head_count_kv(n_head_kv)
+            print(f"gguf: key-value head count = {n_head_kv}")
 
         if (rope_theta := self.hparams.get("rope_theta")) is not None:
             self.gguf_writer.add_rope_freq_base(rope_theta)
+            print(f"gguf: rope theta = {rope_theta}")
         if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
             self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
+            print(f"gguf: rms norm epsilon = {f_rms_eps}")
         if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
             self.gguf_writer.add_layer_norm_eps(f_norm_eps)
+            print(f"gguf: layer norm epsilon = {f_norm_eps}")
         if (n_experts := self.hparams.get("num_local_experts")) is not None:
             self.gguf_writer.add_expert_count(n_experts)
+            print(f"gguf: expert count = {n_experts}")
         if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
             self.gguf_writer.add_expert_used_count(n_experts_used)
+            print(f"gguf: experts used count = {n_experts_used}")
 
         self.gguf_writer.add_file_type(self.ftype)
+        print(f"gguf: file type = {self.ftype}")
 
     def write_tensors(self):
         block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
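
A standalone sketch of the pattern used in the hunk above: optional hyperparameters are looked up with the walrus operator and only written (and now logged) when present. The dict and find_hparam below are toy stand-ins, not the real Model class.

hparams = {"hidden_size": 6144, "num_attention_heads": 48}  # toy config, not a real checkpoint

def find_hparam(keys, optional=False):
    # return the first matching key, mirroring the behaviour of Model.find_hparam
    for k in keys:
        if k in hparams:
            return hparams[k]
    if optional:
        return None
    raise KeyError(f"could not find any of {keys}")

if (n_ctx := find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
    print(f"gguf: context length = {n_ctx}")   # skipped here: key absent in the toy config

n_embd = find_hparam(["hidden_size", "n_embd"])
print(f"gguf: embedding length = {n_embd}")    # gguf: embedding length = 6144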
@@ -1051,6 +1062,21 @@ class MixtralModel(Model):
         self._set_vocab_sentencepiece()
 
 
+@Model.register("GrokForCausalLM")
+class GrokModel(Model):
+    model_arch = gguf.MODEL_ARCH.GROK
+
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_name("Grok")
+
+
 @Model.register("MiniCPMForCausalLM")
 class MiniCPMModel(Model):
     model_arch = gguf.MODEL_ARCH.MINICPM
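
For context, @Model.register keys the class by the architectures entry of the HF config.json, and the converter then resolves the class from that string. A simplified registry sketch of that mechanism (an assumption about how it works, not the script's exact code):

_registry: dict[str, type] = {}

def register(*names):
    # decorator: map each HF architecture name to the model class
    def wrap(cls):
        for n in names:
            _registry[n] = cls
        return cls
    return wrap

@register("GrokForCausalLM")
class GrokModel:
    model_arch = "grok"

hf_config = {"architectures": ["GrokForCausalLM"]}   # as found in the model's config.json
model_cls = _registry[hf_config["architectures"][0]]
print(model_cls.__name__)                            # GrokModel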
index 4a4facb06ea14a0c4ba92f0f173ed6b3bc1604d2..e47896e2a9d3eb74d77e2b5579ffa9378ab54d62 100644 (file)
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -100,6 +100,7 @@ class MODEL_ARCH(IntEnum):
     LLAMA      = auto()
     FALCON     = auto()
     BAICHUAN   = auto()
+    GROK       = auto()
     GPT2       = auto()
     GPTJ       = auto()
     GPTNEOX    = auto()
@@ -167,6 +168,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.LLAMA:          "llama",
     MODEL_ARCH.FALCON:         "falcon",
     MODEL_ARCH.BAICHUAN:       "baichuan",
+    MODEL_ARCH.GROK:           "grok",
     MODEL_ARCH.GPT2:           "gpt2",
     MODEL_ARCH.GPTJ:           "gptj",
     MODEL_ARCH.GPTNEOX:        "gptneox",
@@ -251,6 +253,28 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.GROK: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.LAYER_OUT_NORM,
+    ],
     MODEL_ARCH.GPTNEOX: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
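
A quick way to inspect what the new GROK entry declares, assuming gguf-py from this tree is importable (the enum and tables are exactly the ones edited above):

from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES, MODEL_TENSORS

print(MODEL_ARCH_NAMES[MODEL_ARCH.GROK])   # grok
for t in MODEL_TENSORS[MODEL_ARCH.GROK]:
    print(t.name)                          # TOKEN_EMBD, OUTPUT_NORM, OUTPUT, ...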
index ed89955d8970f1a57b1f5a62d3e23ea6c3f8fc78..11fd34b8b91038cd5aa10999125764a839852da1 100644 (file)
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -23,6 +23,7 @@ class TensorNameMap:
             "model.embedding",                           # mamba-qbert
             "backbone.embedding",                        # mamba
             "backbone.embeddings",                       # mamba-hf
+            "transformer.in_out_embed",                  # Grok
         ),
 
         # Token type embeddings
@@ -66,6 +67,7 @@ class TensorNameMap:
             "lm_head.ln",                              # phi2
             "model.norm_f",                            # mamba-qbert
             "backbone.norm_f",                         # mamba
+            "transformer.rms_norm",                    # Grok
         ),
 
         # Rope frequencies
@@ -93,6 +95,7 @@ class TensorNameMap:
             "model.layers.{bid}.attention_norm",                    # internlm2
             "model.layers.{bid}.norm",                              # mamba-qbert
             "backbone.layers.{bid}.norm",                           # mamba
+            "transformer.decoder_layer.{bid}.rms_norm",             # Grok
         ),
 
         # Attention norm 2
@@ -116,32 +119,35 @@ class TensorNameMap:
 
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj",         # llama-hf
-            "layers.{bid}.attention.wq",                   # llama-pth
-            "encoder.layer.{bid}.attention.self.query",    # bert
-            "transformer.h.{bid}.attn.q_proj",             # gpt-j
-            "model.layers.layers.{bid}.self_attn.q_proj",  # plamo
-            "model.layers.{bid}.attention.wq"             # internlm2
+            "model.layers.{bid}.self_attn.q_proj",                       # llama-hf
+            "layers.{bid}.attention.wq",                                 # llama-pth
+            "encoder.layer.{bid}.attention.self.query",                  # bert
+            "transformer.h.{bid}.attn.q_proj",                           # gpt-j
+            "model.layers.layers.{bid}.self_attn.q_proj",                # plamo
+            "model.layers.{bid}.attention.wq",                           # internlm2
+            "transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok
         ),
 
         # Attention key
         MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj",         # llama-hf
-            "layers.{bid}.attention.wk",                   # llama-pth
-            "encoder.layer.{bid}.attention.self.key",      # bert
-            "transformer.h.{bid}.attn.k_proj",             # gpt-j
-            "model.layers.layers.{bid}.self_attn.k_proj",  # plamo
-            "model.layers.{bid}.attention.wk"             # internlm2
+            "model.layers.{bid}.self_attn.k_proj",                     # llama-hf
+            "layers.{bid}.attention.wk",                               # llama-pth
+            "encoder.layer.{bid}.attention.self.key",                  # bert
+            "transformer.h.{bid}.attn.k_proj",                         # gpt-j
+            "model.layers.layers.{bid}.self_attn.k_proj",              # plamo
+            "model.layers.{bid}.attention.wk",                         # internlm2
+            "transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
         ),
 
         # Attention value
         MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj",         # llama-hf
-            "layers.{bid}.attention.wv",                   # llama-pth
-            "encoder.layer.{bid}.attention.self.value",    # bert
-            "transformer.h.{bid}.attn.v_proj",             # gpt-j
-            "model.layers.layers.{bid}.self_attn.v_proj",  # plamo
-            "model.layers.{bid}.attention.wv"             # internlm2
+            "model.layers.{bid}.self_attn.v_proj",                       # llama-hf
+            "layers.{bid}.attention.wv",                                 # llama-pth
+            "encoder.layer.{bid}.attention.self.value",                  # bert
+            "transformer.h.{bid}.attn.v_proj",                           # gpt-j
+            "model.layers.layers.{bid}.self_attn.v_proj",                # plamo
+            "model.layers.{bid}.attention.wv",                           # internlm2
+            "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
         ),
 
         # Attention output
@@ -162,12 +168,14 @@ class TensorNameMap:
             "model.layers.layers.{bid}.self_attn.o_proj",                # plamo
             "model.layers.{bid}.attention.wo",                           # internlm2
             "encoder.layers.{bid}.attn.out_proj",                        # nomic-bert
+            "transformer.decoder_layer.{bid}.multi_head_attention.linear"# Grok
         ),
 
         # Attention output norm
         MODEL_TENSOR.ATTN_OUT_NORM: (
             "encoder.layer.{bid}.attention.output.LayerNorm",  # bert
             "encoder.layers.{bid}.norm1",                      # nomic-bert
+            "transformer.decoder_layer.{bid}.rms_norm_1",      # Grok
         ),
 
         # Rotary embeddings
@@ -190,11 +198,13 @@ class TensorNameMap:
             "model.layers.{bid}.ln2",                                        # yi
             "h.{bid}.ln_2",                                                  # gpt2
             "model.layers.{bid}.ffn_norm",                                   # internlm2
+            "transformer.decoder_layer.{bid}.rms_norm_2",                    # Grok
         ),
 
         MODEL_TENSOR.FFN_GATE_INP: (
             "layers.{bid}.feed_forward.gate",           # mixtral
             "model.layers.{bid}.block_sparse_moe.gate", # mixtral
+            "transformer.decoder_layer.{bid}.router"    # Grok
         ),
 
         # Feed-forward up
@@ -223,6 +233,7 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_UP_EXP: (
             "layers.{bid}.feed_forward.experts.{xid}.w3",           # mixtral
             "model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral
+            "transformer.decoder_layer.{bid}.moe.{xid}.linear_v",   # Grok
         ),
 
         # AWQ-activation gate
@@ -243,6 +254,7 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_GATE_EXP: (
             "layers.{bid}.feed_forward.experts.{xid}.w1",           # mixtral
             "model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral
+            "transformer.decoder_layer.{bid}.moe.{xid}.linear"      # Grok
         ),
 
         # Feed-forward down
@@ -270,6 +282,8 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_DOWN_EXP: (
             "layers.{bid}.feed_forward.experts.{xid}.w2",           # mixtral
             "model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral
+            "transformer.decoder_layer.{bid}.moe.{xid}.linear_1",   # Grok
+
         ),
 
         MODEL_TENSOR.ATTN_Q_NORM: (
@@ -287,8 +301,9 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.LAYER_OUT_NORM: (
-            "encoder.layer.{bid}.output.LayerNorm",  # bert
-            "encoder.layers.{bid}.norm2",            # nomic-bert
+            "encoder.layer.{bid}.output.LayerNorm",         # bert
+            "encoder.layers.{bid}.norm2",                   # nomic-bert
+            "transformer.decoder_layer.{bid}.rms_norm_3",   # Grok
         ),
 
         MODEL_TENSOR.SSM_IN: (
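
A small check of the new Grok mappings: TensorNameMap resolves the checkpoint tensor names added above to the GGUF names defined in constants.py. Illustrative only; it assumes gguf-py from this tree is on the path.

import gguf

# checkpoint name -> GGUF name map for Grok, expanded for 64 blocks
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.GROK, 64)

name = tmap.get_name("transformer.decoder_layer.0.multi_head_attention.query.weight",
                     try_suffixes=(".weight", ".bias"))
print(name)   # expected: blk.0.attn_q.weight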
index eedca802b86a78442f43a9f588881ef317d660ce..4e08be18d7c275f72f0e821b09b8cda536ce3c40 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -195,6 +195,7 @@ enum llm_arch {
     LLM_ARCH_LLAMA,
     LLM_ARCH_FALCON,
     LLM_ARCH_BAICHUAN,
+    LLM_ARCH_GROK,
     LLM_ARCH_GPT2,
     LLM_ARCH_GPTJ,
     LLM_ARCH_GPTNEOX,
@@ -224,6 +225,7 @@ enum llm_arch {
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLAMA,           "llama"      },
     { LLM_ARCH_FALCON,          "falcon"     },
+    { LLM_ARCH_GROK,            "grok"       },
     { LLM_ARCH_GPT2,            "gpt2"       },
     { LLM_ARCH_GPTJ,            "gptj"       },
     { LLM_ARCH_GPTNEOX,         "gptneox"    },
@@ -494,6 +496,28 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_GROK,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
+        },
+    },
     {
         LLM_ARCH_GPT2,
         {
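
Note that the per-expert entries above carry two %d placeholders: the first is the layer index, the second the expert index. Roughly how the names expand (plain Python, not the actual C++ tn() helper; 8 experts is Grok-1's published expert count, not something read from this diff):

fmt = "blk.%d.ffn_gate.%d"
names = [f"{fmt % (layer, expert)}.weight" for layer in range(2) for expert in range(8)]
print(names[0], names[-1])   # blk.0.ffn_gate.0.weight blk.1.ffn_gate.7.weight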
@@ -1635,6 +1659,7 @@ enum e_model {
     MODEL_40B,
     MODEL_65B,
     MODEL_70B,
+    MODEL_314B,
     MODEL_SMALL,
     MODEL_MEDIUM,
     MODEL_LARGE,
@@ -3419,6 +3444,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_40B:    return "40B";
         case MODEL_65B:    return "65B";
         case MODEL_70B:    return "70B";
+        case MODEL_314B:   return "314B";
         case MODEL_SMALL:  return "0.1B";
         case MODEL_MEDIUM: return "0.4B";
         case MODEL_LARGE:  return "0.8B";
@@ -3557,6 +3583,15 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GROK:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 64: model.type = e_model::MODEL_314B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_FALCON:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -4394,6 +4429,54 @@ static bool llm_load_tensors(
                         }
                     }
                 } break;
+            case LLM_ARCH_GROK:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, false);
+                        // if output is NULL, init from the input tok embed
+                        if (model.output == NULL) {
+                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                            ml.n_created--; // artificial tensor
+                            ml.size_data += ggml_nbytes(model.output);
+                        }
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                        layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+
+                        layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd});
+
+                        GGML_ASSERT(hparams.n_expert      > 0);
+                        GGML_ASSERT(hparams.n_expert_used > 0);
+
+                        // MoE branch
+                        for (uint32_t x = 0; x < hparams.n_expert; ++x) {
+                            layer.ffn_gate_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd,   n_ff});
+                            layer.ffn_down_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), {  n_ff, n_embd});
+                            layer.ffn_up_exp[x]   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, x), {n_embd,   n_ff});
+                        }
+
+                        layer.layer_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
+                    }
+                } break;
             case LLM_ARCH_BAICHUAN:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
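
A back-of-the-envelope check of the GQA projection shapes created above, plugging in Grok-1's published hyperparameters (48 query heads, 8 KV heads, hidden size 6144; these numbers come from the released model, not from this diff):

n_embd, n_head, n_head_kv = 6144, 48, 8

n_embd_head = n_embd // n_head           # 128 per-head dimension
n_embd_gqa  = n_embd_head * n_head_kv    # 1024: K/V are shared across query-head groups

print("wq:", (n_embd, n_embd))       # (6144, 6144)
print("wk:", (n_embd, n_embd_gqa))   # (6144, 1024)
print("wv:", (n_embd, n_embd_gqa))   # (6144, 1024)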
@@ -5621,6 +5704,20 @@ static struct ggml_tensor * llm_build_kqv(
         ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
     }
 
+    if (model.arch == LLM_ARCH_GROK) {
+        // need to do the following:
+        // multiply by attn_output_multiplier of 0.08838834764831845
+        // and then:
+        // kq = 30 * tanh(kq / 30)
+        // before the softmax below
+
+        //try from phi2
+        //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+
+        kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
+        kq = ggml_scale(ctx, kq, 30);
+    }
+
 #if defined(GGML_USE_KOMPUTE)
 #pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute")
 #pragma message("      Falling back to ggml_alibi(). Will become an error in Mar 2024")
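
Numerically, the two ggml_scale/ggml_tanh calls above implement a soft cap on the attention logits: they are scaled by 0.08838834764831845 (which matches 1/sqrt(128), Grok-1's head dimension; build_grok below passes kq_scale = 1.0f to llm_build_kv, so the usual 1/sqrt(d) factor is folded in here) and then squashed so they never exceed +/-30 before the softmax. A tiny sketch of the same arithmetic:

import math

def grok_soft_cap(kq, scale=0.08838834764831845, cap=30.0):
    # 30 * tanh(scale * kq / 30): roughly linear for small logits, saturating at +/-30
    return cap * math.tanh(kq * scale / cap)

for x in (0.0, 10.0, 100.0, 1_000.0, 1_000_000.0):
    print(x, round(grok_soft_cap(x), 4))
# large raw logits saturate near 30.0 instead of growing without bound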
@@ -6395,6 +6492,203 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_grok() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // multiply by embedding_multiplier_scale of 78.38367176906169
+        inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+            }
+
+            // Grok
+            // if attn_out_norm is present then apply it before adding the input
+            if (model.layers[il].attn_out_norm) {
+                cur = llm_build_norm(ctx0, cur, hparams,
+                        model.layers[il].attn_out_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "attn_out_norm", il);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            // MoE branch
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
+            cb(logits, "ffn_moe_logits", il);
+
+            ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
+            cb(probs, "ffn_moe_probs", il);
+
+            // select experts
+            ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
+            cb(selected_experts->src[0], "ffn_moe_argsort", il);
+
+            ggml_tensor * weights = ggml_get_rows(ctx0,
+                    ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
+            cb(weights, "ffn_moe_weights", il);
+
+            weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
+
+            ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
+            cb(weights_sum, "ffn_moe_weights_sum", il);
+
+            weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+            cb(weights, "ffn_moe_weights_norm", il);
+
+            // compute expert outputs
+            ggml_tensor * moe_out = nullptr;
+
+            for (int i = 0; i < n_expert_used; ++i) {
+                ggml_tensor * cur_expert;
+
+                ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur);
+                cb(cur_up, "ffn_moe_up", il);
+
+                ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur);
+                cb(cur_gate, "ffn_moe_gate", il);
+
+                //GeLU
+                cur_gate = ggml_gelu(ctx0, cur_gate);
+                cb(cur_gate, "ffn_moe_gelu", il);
+
+                cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd]
+                cb(cur_expert, "ffn_moe_gate_par", il);
+
+                cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+                cb(cur_expert, "ffn_moe_down", il);
+
+                cur_expert = ggml_mul(ctx0, cur_expert,
+                        ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+                cb(cur_expert, "ffn_moe_weighted", il);
+
+                if (i == 0) {
+                    moe_out = cur_expert;
+                } else {
+                    moe_out = ggml_add(ctx0, moe_out, cur_expert);
+                    cb(moe_out, "ffn_moe_out", il);
+                }
+            }
+
+            cur = moe_out;
+
+            // Grok
+            // if layer_out_norm is present then apply it before adding the input
+            // Idea: maybe ffn_out_norm is a better name
+            if (model.layers[il].layer_out_norm) {
+                cur = llm_build_norm(ctx0, cur, hparams,
+                        model.layers[il].layer_out_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "layer_out_norm", il);
+            }
+
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+            if (layer_dir != nullptr) {
+                cur = ggml_add(ctx0, cur, layer_dir);
+            }
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+
+        // Grok
+        // multiply logits by output_multiplier_scale of 0.5773502691896257
+
+        cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_starcoder() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -8818,6 +9112,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_falcon();
             } break;
+        case LLM_ARCH_GROK:
+            {
+                result = llm.build_grok();
+            } break;
         case LLM_ARCH_STARCODER:
             {
                 result = llm.build_starcoder();
@@ -13561,6 +13859,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
 
         // the pairs of head values are offset by n_rot/2
         case LLM_ARCH_FALCON:
+        case LLM_ARCH_GROK:
         case LLM_ARCH_PERSIMMON:
         case LLM_ARCH_BERT:
         case LLM_ARCH_NOMIC_BERT:
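
Grok is grouped here with the architectures whose RoPE pairs dimension i with dimension i + n_rot/2 (the "neox" style) rather than adjacent dimensions (2i, 2i+1). A toy illustration of the two pairings (not llama.cpp code):

def rope_pairs(n_rot, neox=True):
    half = n_rot // 2
    if neox:
        return [(i, i + half) for i in range(half)]    # dims offset by n_rot/2
    return [(2 * i, 2 * i + 1) for i in range(half)]   # adjacent dims

print(rope_pairs(8, neox=True))    # [(0, 4), (1, 5), (2, 6), (3, 7)]
print(rope_pairs(8, neox=False))   # [(0, 1), (2, 3), (4, 5), (6, 7)]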