]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model: support arch `DbrxForCausalLM` (#6515)
authorPierrick Hymbert <redacted>
Sat, 13 Apr 2024 09:33:52 +0000 (11:33 +0200)
committerGitHub <redacted>
Sat, 13 Apr 2024 09:33:52 +0000 (11:33 +0200)
* model: dbrx convert to gguf
#6344

* llama: support dbrx
#6344

* doc: dbrx: add the model as supported

* scripts: get-wikitext-2 add unzip

* llama: increase maximum experts allowed

* llama: factorize moe graph implementation between grok, mixtral and dbrx

---------

Co-authored-by: Megha Agarwal <redacted>
README.md
convert-hf-to-gguf.py
examples/eval-callback/eval-callback.cpp
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
llama.cpp
scripts/get-wikitext-2.sh

index 00a487fc6868eaa3a1db805e3e4c28d1a6281315..fa2a4db834df2825761af54e236757539f3db82c 100644 (file)
--- a/README.md
+++ b/README.md
@@ -94,6 +94,7 @@ Typically finetunes of the base models below are supported as well.
 - [x] LLaMA 2 ðŸ¦™ðŸ¦™
 - [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1)
 - [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
+- [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct)
 - [X] Falcon
 - [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2)
 - [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne)
index 63710676bad1c485bea4fbb951f2e1a200f914af..e1ac09e024b117583f98bb22aad232fb7424f18d 100755 (executable)
@@ -1427,6 +1427,102 @@ class GrokModel(Model):
             self.gguf_writer.add_tensor(new_name, data)
 
 
+@Model.register("DbrxForCausalLM")
+class DbrxModel(Model):
+    model_arch = gguf.MODEL_ARCH.DBRX
+
+    def set_gguf_parameters(self):
+        ffn_config = self.hparams["ffn_config"]
+        attn_config = self.hparams["attn_config"]
+        self.gguf_writer.add_name(self.hparams["model_type"])
+        self.gguf_writer.add_block_count(self.hparams["n_layers"])
+
+        self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
+        self.gguf_writer.add_embedding_length(self.hparams["d_model"])
+        self.gguf_writer.add_feed_forward_length(ffn_config["ffn_hidden_size"])
+
+        self.gguf_writer.add_head_count(self.hparams["n_heads"])
+        self.gguf_writer.add_head_count_kv(attn_config["kv_n_heads"])
+
+        self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"])
+
+        self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+        self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"])
+        self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"])
+
+        self.gguf_writer.add_layer_norm_eps(1e-5)
+
+        self.gguf_writer.add_file_type(self.ftype)
+        print(f"gguf: file type = {self.ftype}")
+
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers")
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        for name, data_torch in self.get_tensors():
+            n_expert = self.hparams["ffn_config"]["moe_num_experts"]
+            n_ff = self.hparams["ffn_config"]["ffn_hidden_size"]
+            n_embd = self.hparams["d_model"]
+
+            # Specific behavior for experts tensors: suffix .weight, view as 3D and transpose
+            # original implementation expects (n_expert, n_ff, n_embd) for all experts weights
+            # But llama.cpp moe graph works differently
+            # AND the dimensions in ggml are typically in the reverse order of the pytorch dimensions
+            # so (n_expert, n_ff, n_embd) in pytorch is {n_embd, n_ff, n_expert} in ggml_tensor
+            exp_tensor_names = {"ffn.experts.mlp.w1": None,       # LLM_TENSOR_FFN_GATE_EXPS ggml_tensor->ne{n_embd, n_ff,   n_expert}
+                                "ffn.experts.mlp.w2": (0, 2, 1),  # LLM_TENSOR_FFN_DOWN_EXPS ggml_tensor->ne{n_ff,   n_embd, n_expert}
+                                "ffn.experts.mlp.v1": None}       # LLM_TENSOR_FFN_UP_EXPS   ggml_tensor->ne{n_embd, n_ff,   n_expert}
+            experts = False
+            for exp_tensor_name in exp_tensor_names.keys():
+                if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1:
+                    experts = True
+                    data_torch = data_torch.view(n_expert, n_ff, n_embd)
+                    if (permute_tensor := exp_tensor_names[exp_tensor_name]) is not None:
+                        data_torch = data_torch.permute(*permute_tensor)
+                    break
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            # In MoE models the ffn tensors are typically most of the model weights,
+            # and need to be quantizable. Quantize expects tensor names to be suffixed by .weight.
+            # Every other model has the weight names ending in .weight,
+            # let's assume that is the convention which is not the case for dbrx:
+            # https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15
+            new_name = tensor_map.get_name(name if not experts else name + ".weight", try_suffixes=(".weight",))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # Most of the codebase that takes in 1D tensors only handles F32 tensors
+            # and most of the outputs tensors are F32.
+            if data_dtype != np.float32 and n_dims == 1:
+                print(f"Can not map tensor {name!r}: all 1D tensors must be F32")
+                sys.exit()
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and n_dims > 1:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+
 @Model.register("MiniCPMForCausalLM")
 class MiniCPMModel(Model):
     model_arch = gguf.MODEL_ARCH.MINICPM
index 05f7d6ab1a2d333296fb1f0798b07605a1bd57be..29b5f3b3c12c83d5669f10a5c3871791f03ae6f4 100644 (file)
@@ -28,14 +28,27 @@ static std::string ggml_ne_string(const ggml_tensor * t) {
 }
 
 static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
+    GGML_ASSERT(n > 0);
     float sum = 0;
     for (int64_t i3 = 0; i3 < ne[3]; i3++) {
         printf("                                     [\n");
-        for (int64_t i2 = 0; i2 < ne[2] && i2 < n; i2++) {
+        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
+            if (i2 == n && ne[2] > 2*n) {
+                printf("                                      ..., \n");
+                i2 = ne[2] - n;
+            }
             printf("                                      [\n");
-            for (int64_t i1 = 0; i1 < ne[1] && i1 < n; i1++) {
+            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
+                if (i1 == n && ne[1] > 2*n) {
+                    printf("                                       ..., \n");
+                    i1 = ne[1] - n;
+                }
                 printf("                                       [");
-                for (int64_t i0 = 0; i0 < ne[0] && i0 < n; i0++) {
+                for (int64_t i0 = 0; i0 < ne[0]; i0++) {
+                    if (i0 == n && ne[0] > 2*n) {
+                        printf("..., ");
+                        i0 = ne[0] - n;
+                    }
                     size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
                     float v;
                     if (type == GGML_TYPE_F16) {
@@ -51,17 +64,14 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
                     } else {
                         GGML_ASSERT(false);
                     }
-                    printf("%8.4f", v);
+                    printf("%12.4f", v);
                     sum += v;
-                    if (i0 < ne[0] - 1 && i0 < n - 1) printf(", ");
+                    if (i0 < ne[0] - 1) printf(", ");
                 }
-                if (ne[0] > n) printf(", ...");
                 printf("],\n");
             }
-            if (ne[1] > n) printf("                                       ...\n");
             printf("                                      ],\n");
         }
-        if (ne[2] > n) printf("                                     ...\n");
         printf("                                     ]\n");
         printf("                                     sum = %f\n", sum);
     }
index a6454a10e20b95adbb58ecde087001168ed97e85..2566b2fb8f2966084dba6189f9783ba3528349ac 100644 (file)
@@ -126,6 +126,7 @@ class MODEL_ARCH(IntEnum):
     MAMBA      = auto()
     XVERSE     = auto()
     COMMAND_R  = auto()
+    DBRX       = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -195,6 +196,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.MAMBA:          "mamba",
     MODEL_ARCH.XVERSE:         "xverse",
     MODEL_ARCH.COMMAND_R:      "command-r",
+    MODEL_ARCH.DBRX:           "dbrx",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -642,6 +644,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ATTN_K_NORM,
         MODEL_TENSOR.ATTN_Q_NORM,
     ],
+    MODEL_ARCH.DBRX: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     # TODO
 }
 
index 4f02d298e13debb8ccbf52d9abc826c05a9b7a82..ec6fcbb838425633f70e76aef5865a0a385f335a 100644 (file)
@@ -10,7 +10,7 @@ class TensorNameMap:
         # Token embeddings
         MODEL_TENSOR.TOKEN_EMBD: (
             "gpt_neox.embed_in",                         # gptneox
-            "transformer.wte",                           # gpt2 gpt-j mpt refact qwen
+            "transformer.wte",                           # gpt2 gpt-j mpt refact qwen dbrx
             "transformer.word_embeddings",               # falcon
             "word_embeddings",                           # bloom
             "model.embed_tokens",                        # llama-hf
@@ -48,7 +48,7 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out",                 # gptneox
-            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba
+            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx
             "output",                    # llama-pth bloom internlm2
             "word_embeddings_for_head",  # persimmon
             "lm_head.linear",            # phi2
@@ -60,7 +60,7 @@ class TensorNameMap:
             "transformer.ln_f",                        # gpt2 gpt-j falcon
             "model.norm",                              # llama-hf baichuan internlm2
             "norm",                                    # llama-pth
-            "transformer.norm_f",                      # mpt
+            "transformer.norm_f",                      # mpt dbrx
             "ln_f",                                    # refact bloom qwen gpt2
             "language_model.encoder.final_layernorm",  # persimmon
             "model.final_layernorm",                   # persimmon
@@ -96,6 +96,7 @@ class TensorNameMap:
             "model.layers.{bid}.norm",                              # mamba-qbert
             "backbone.layers.{bid}.norm",                           # mamba
             "transformer.decoder_layer.{bid}.rms_norm",             # Grok
+            "transformer.blocks.{bid}.norm_attn_norm.norm_1",       # dbrx
         ),
 
         # Attention norm 2
@@ -108,6 +109,7 @@ class TensorNameMap:
             "gpt_neox.layers.{bid}.attention.query_key_value",                     # gptneox
             "transformer.h.{bid}.attn.c_attn",                                     # gpt2 qwen
             "transformer.blocks.{bid}.attn.Wqkv",                                  # mpt
+            "transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv",                   # dbrx
             "transformer.h.{bid}.self_attention.query_key_value",                  # falcon
             "h.{bid}.self_attention.query_key_value",                              # bloom
             "language_model.encoder.layers.{bid}.self_attention.query_key_value",  # persimmon
@@ -152,23 +154,24 @@ class TensorNameMap:
 
         # Attention output
         MODEL_TENSOR.ATTN_OUT: (
-            "gpt_neox.layers.{bid}.attention.dense",                     # gptneox
-            "transformer.h.{bid}.attn.c_proj",                           # gpt2 refact qwen
-            "transformer.blocks.{bid}.attn.out_proj",                    # mpt
-            "transformer.h.{bid}.self_attention.dense",                  # falcon
-            "h.{bid}.self_attention.dense",                              # bloom
-            "model.layers.{bid}.self_attn.o_proj",                       # llama-hf
-            "layers.{bid}.attention.wo",                                 # llama-pth
-            "encoder.layer.{bid}.attention.output.dense",                # bert
-            "transformer.h.{bid}.attn.out_proj",                         # gpt-j
-            "language_model.encoder.layers.{bid}.self_attention.dense",  # persimmon
-            "model.layers.{bid}.self_attn.dense",                        # persimmon
-            "h.{bid}.attn.c_proj",                                       # gpt2
-            "transformer.h.{bid}.mixer.out_proj",                        # phi2
-            "model.layers.layers.{bid}.self_attn.o_proj",                # plamo
-            "model.layers.{bid}.attention.wo",                           # internlm2
-            "encoder.layers.{bid}.attn.out_proj",                        # nomic-bert
-            "transformer.decoder_layer.{bid}.multi_head_attention.linear"# Grok
+            "gpt_neox.layers.{bid}.attention.dense",                        # gptneox
+            "transformer.h.{bid}.attn.c_proj",                              # gpt2 refact qwen
+            "transformer.blocks.{bid}.attn.out_proj",                       # mpt
+            "transformer.h.{bid}.self_attention.dense",                     # falcon
+            "h.{bid}.self_attention.dense",                                 # bloom
+            "model.layers.{bid}.self_attn.o_proj",                          # llama-hf
+            "layers.{bid}.attention.wo",                                    # llama-pth
+            "encoder.layer.{bid}.attention.output.dense",                   # bert
+            "transformer.h.{bid}.attn.out_proj",                            # gpt-j
+            "language_model.encoder.layers.{bid}.self_attention.dense",     # persimmon
+            "model.layers.{bid}.self_attn.dense",                           # persimmon
+            "h.{bid}.attn.c_proj",                                          # gpt2
+            "transformer.h.{bid}.mixer.out_proj",                           # phi2
+            "model.layers.layers.{bid}.self_attn.o_proj",                   # plamo
+            "model.layers.{bid}.attention.wo",                              # internlm2
+            "encoder.layers.{bid}.attn.out_proj",                           # nomic-bert
+            "transformer.decoder_layer.{bid}.multi_head_attention.linear",  # Grok
+            "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",        # dbrx
         ),
 
         # Attention output norm
@@ -176,6 +179,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.attention.output.LayerNorm",  # bert
             "encoder.layers.{bid}.norm1",                      # nomic-bert
             "transformer.decoder_layer.{bid}.rms_norm_1",      # Grok
+            "transformer.blocks.{bid}.norm_attn_norm.norm_2",  # dbrx
         ),
 
         # Rotary embeddings
@@ -202,9 +206,10 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_GATE_INP: (
-            "layers.{bid}.feed_forward.gate",           # mixtral
-            "model.layers.{bid}.block_sparse_moe.gate", # mixtral
-            "transformer.decoder_layer.{bid}.router"    # Grok
+            "layers.{bid}.feed_forward.gate",             # mixtral
+            "model.layers.{bid}.block_sparse_moe.gate",   # mixtral
+            "transformer.decoder_layer.{bid}.router",     # Grok
+            "transformer.blocks.{bid}.ffn.router.layer",  # dbrx
         ),
 
         # Feed-forward up
@@ -233,6 +238,7 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_UP_EXP: (
             "layers.{bid}.feed_forward.experts.w3",                 # mixtral (merged)
             "transformer.decoder_layer.{bid}.moe.linear_v",         # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.v1",          # dbrx
         ),
 
         # AWQ-activation gate
@@ -251,8 +257,9 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
-            "layers.{bid}.feed_forward.experts.w1",                 # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear"            # Grok (merged)
+            "layers.{bid}.feed_forward.experts.w1",         # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear",   # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.w1",  # dbrx
         ),
 
         # Feed-forward down
@@ -280,6 +287,7 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_DOWN_EXP: (
             "layers.{bid}.feed_forward.experts.w2",                 # mixtral (merged)
             "transformer.decoder_layer.{bid}.moe.linear_1",         # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.w2",          # dbrx
         ),
 
         MODEL_TENSOR.ATTN_Q_NORM: (
index 83dd55efa99e6849b1fbbad97d1d3db2a3358840..b93c1abcd85d6f3a6986e11c5060c8506611b268 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
 #endif
 
 #define LLAMA_MAX_NODES   8192
-#define LLAMA_MAX_EXPERTS 8
+#define LLAMA_MAX_EXPERTS 16
 
 
 //
@@ -220,6 +220,7 @@ enum llm_arch {
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
     LLM_ARCH_COMMAND_R,
+    LLM_ARCH_DBRX,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -252,6 +253,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MAMBA,           "mamba"      },
     { LLM_ARCH_XVERSE,          "xverse"     },
     { LLM_ARCH_COMMAND_R,       "command-r"  },
+    { LLM_ARCH_DBRX,            "dbrx"       },
     { LLM_ARCH_UNKNOWN,         "(unknown)"  },
 };
 
@@ -934,6 +936,22 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
         },
     },
+    {
+        LLM_ARCH_DBRX,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1707,6 +1725,7 @@ enum e_model {
     MODEL_XL,
     MODEL_8x7B,
     MODEL_8x22B,
+    MODEL_16x12B,
 };
 
 static const size_t kiB = 1024;
@@ -3562,6 +3581,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_XL:     return "1.5B";
         case MODEL_8x7B:   return "8x7B";
         case MODEL_8x22B:  return "8x22B";
+        case MODEL_16x12B: return "16x12B";
         default:           return "?B";
     }
 }
@@ -3983,6 +4003,16 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_DBRX:
+        {
+            ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,  hparams.f_norm_eps);
+            ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV,      hparams.f_clamp_kqv);
+
+            switch (hparams.n_layer) {
+                case 40: model.type = e_model::MODEL_16x12B; break;
+                default: model.type = e_model::MODEL_UNKNOWN;
+            }
+        } break;
         default: (void)0;
     }
 
@@ -4671,6 +4701,39 @@ static bool llm_load_tensors(
                         layer.layer_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
                     }
                 } break;
+            case LLM_ARCH_DBRX:
+            {
+                if (n_expert == 0) {
+                    throw std::runtime_error("DBRX model cannot have zero experts");
+                }
+
+                model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                // output
+                {
+                    model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                    model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                }
+
+                for (int i = 0; i < n_layer; ++i) {
+                    ggml_context * ctx_layer = ctx_for_layer(i);
+                    ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                    auto & layer = model.layers[i];
+
+                    layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,  "weight", i), {n_embd});
+
+                    layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+                    layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                    layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
+
+                    layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
+                    layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert});
+                    layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert});
+                    layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert});
+                }
+            } break;
             case LLM_ARCH_BAICHUAN:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -6433,62 +6496,7 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
-                cb(logits, "ffn_moe_logits", il);
-
-                ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
-                cb(probs, "ffn_moe_probs", il);
-
-                // select experts
-                ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
-                cb(selected_experts->src[0], "ffn_moe_argsort", il);
-
-                ggml_tensor * weights = ggml_get_rows(ctx0,
-                        ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
-                cb(weights, "ffn_moe_weights", il);
-
-                weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
-
-                ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
-                cb(weights_sum, "ffn_moe_weights_sum", il);
-
-                weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
-                cb(weights, "ffn_moe_weights_norm", il);
-
-                // compute expert outputs
-                ggml_tensor * moe_out = nullptr;
-
-                for (int i = 0; i < n_expert_used; ++i) {
-                    ggml_tensor * cur_expert;
-
-                    ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
-                    cb(cur_up, "ffn_moe_up", il);
-
-                    ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
-                    cb(cur_gate, "ffn_moe_gate", il);
-
-                    cur_gate = ggml_silu(ctx0, cur_gate);
-                    cb(cur_gate, "ffn_moe_silu", il);
-
-                    cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
-                    cb(cur_expert, "ffn_moe_gate_par", il);
-
-                    cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
-                    cb(cur_expert, "ffn_moe_down", il);
-
-                    cur_expert = ggml_mul(ctx0, cur_expert,
-                            ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
-                    cb(cur_expert, "ffn_moe_weighted", il);
-
-                    if (i == 0) {
-                        moe_out = cur_expert;
-                    } else {
-                        moe_out = ggml_add(ctx0, moe_out, cur_expert);
-                        cb(moe_out, "ffn_moe_out", il);
-                    }
-                }
-
-                cur = moe_out;
+                cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il);
             }
 
             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -6520,6 +6528,78 @@ struct llm_build_context {
         return gf;
     }
 
+    // REVIEW: will be replaced by https://github.com/ggerganov/llama.cpp/pull/6505
+    ggml_tensor * build_moe_ffn(ggml_tensor * cur, int32_t n_tokens, llm_ffn_op_type type_op, int il) {
+        ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
+        cb(logits, "ffn_moe_logits", il);
+
+        ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
+        cb(probs, "ffn_moe_probs", il);
+
+        // select experts
+        ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
+        cb(selected_experts->src[0], "ffn_moe_argsort", il);
+
+        ggml_tensor * weights = ggml_get_rows(ctx0,
+                                              ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
+        cb(weights, "ffn_moe_weights", il);
+
+        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
+
+        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
+        cb(weights_sum, "ffn_moe_weights_sum", il);
+
+        weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+        cb(weights, "ffn_moe_weights_norm", il);
+
+        // compute expert outputs
+        ggml_tensor * moe_out = nullptr;
+
+        for (int i = 0; i < n_expert_used; ++i) {
+            ggml_tensor * cur_expert;
+
+            ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
+            cb(cur_up, "ffn_moe_up", il);
+
+            ggml_tensor * gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
+            cb(gate, "ffn_moe_gate", il);
+
+            switch (type_op) {
+                case LLM_FFN_SILU:
+                {
+                    gate = ggml_silu(ctx0, gate);
+                    cb(gate, "ffn_moe_silu", il);
+                } break;
+                case LLM_FFN_GELU:
+                {
+                    gate = ggml_gelu(ctx0, gate);
+                    cb(gate, "ffn_moe_gelu", il);
+                } break;
+                default:
+                    GGML_ASSERT(false);
+            }
+
+            cur_expert = ggml_mul(ctx0, cur_up, gate);
+            cb(cur_expert, "ffn_moe_gate_par", il);
+
+            cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+            cb(cur_expert, "ffn_moe_down", il);
+
+            cur_expert = ggml_mul(ctx0, cur_expert,
+                                  ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+            cb(cur_expert, "ffn_moe_weighted", il);
+
+            if (i == 0) {
+                moe_out = cur_expert;
+            } else {
+                moe_out = ggml_add(ctx0, moe_out, cur_expert);
+                cb(moe_out, "ffn_moe_out", il);
+            }
+        }
+
+        return moe_out;
+    }
+
     struct ggml_cgraph * build_baichuan() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -6967,74 +7047,143 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
-            cb(logits, "ffn_moe_logits", il);
+            cur = build_moe_ffn(cur, n_tokens, LLM_FFN_GELU, il);
+
+            // Grok
+            // if layer_out_norm is present then apply it before adding the input
+            // Idea: maybe ffn_out_norm is a better name
+            if (model.layers[il].layer_out_norm) {
+                cur = llm_build_norm(ctx0, cur, hparams,
+                        model.layers[il].layer_out_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "layer_out_norm", il);
+            }
+
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+            if (layer_dir != nullptr) {
+                cur = ggml_add(ctx0, cur, layer_dir);
+            }
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
 
-            ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
-            cb(probs, "ffn_moe_probs", il);
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
 
-            // select experts
-            ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
-            cb(selected_experts->src[0], "ffn_moe_argsort", il);
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
 
-            ggml_tensor * weights = ggml_get_rows(ctx0,
-                    ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
-            cb(weights, "ffn_moe_weights", il);
+        // Grok
+        // multiply logits by output_multiplier_scale of 0.5773502691896257
 
-            weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
+        cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
 
-            ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
-            cb(weights_sum, "ffn_moe_weights_sum", il);
+        cb(cur, "result_output", -1);
 
-            weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
-            cb(weights, "ffn_moe_weights_norm", il);
+        ggml_build_forward_expand(gf, cur);
 
-            // compute expert outputs
-            ggml_tensor * moe_out = nullptr;
+        return gf;
+    }
 
-            for (int i = 0; i < n_expert_used; ++i) {
-                ggml_tensor * cur_expert;
+    struct ggml_cgraph * build_dbrx() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
-                ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
-                cb(cur_up, "ffn_moe_up", il);
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
 
-                ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
-                cb(cur_gate, "ffn_moe_gate", il);
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
 
-                //GeLU
-                cur_gate = ggml_gelu(ctx0, cur_gate);
-                cb(cur_gate, "ffn_moe_gelu", il);
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
 
-                cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
-                cb(cur_expert, "ffn_moe_gate_par", il);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
-                cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
-                cb(cur_expert, "ffn_moe_down", il);
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
 
-                cur_expert = ggml_mul(ctx0, cur_expert,
-                        ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
-                cb(cur_expert, "ffn_moe_weighted", il);
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
-                if (i == 0) {
-                    moe_out = cur_expert;
-                } else {
-                    moe_out = ggml_add(ctx0, moe_out, cur_expert);
-                    cb(moe_out, "ffn_moe_out", il);
-                }
-            }
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
 
-            cur = moe_out;
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                                 model.layers[il].attn_norm, NULL,
+                                 LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
 
-            // Grok
-            // if layer_out_norm is present then apply it before adding the input
-            // Idea: maybe ffn_out_norm is a better name
-            if (model.layers[il].layer_out_norm) {
-                cur = llm_build_norm(ctx0, cur, hparams,
-                        model.layers[il].layer_out_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "layer_out_norm", il);
+            // self-attention
+            {
+                struct ggml_tensor * Qcur = nullptr;
+                struct ggml_tensor * Kcur = nullptr;
+                struct ggml_tensor * Vcur = nullptr;
+
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                cb(cur, "wqkv_clamped", il);
+
+                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                                   model.layers[il].wo, NULL,
+                                   Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            // MoE branch
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                                 model.layers[il].attn_out_norm, NULL,
+                                 LLM_NORM, cb, il);
+            cb(cur, "attn_out_norm", il);
+
+            cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il);
 
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
@@ -7052,18 +7201,13 @@ struct llm_build_context {
         cur = inpL;
 
         cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
+                             model.output_norm, NULL,
+                             LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
         // lm_head
         cur = ggml_mul_mat(ctx0, model.output, cur);
 
-        // Grok
-        // multiply logits by output_multiplier_scale of 0.5773502691896257
-
-        cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
-
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -9785,6 +9929,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_command_r();
             } break;
+        case LLM_ARCH_DBRX:
+            {
+                result = llm.build_dbrx();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -14638,6 +14786,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         // the pairs of head values are offset by n_rot/2
         case LLM_ARCH_FALCON:
         case LLM_ARCH_GROK:
+        case LLM_ARCH_DBRX:
         case LLM_ARCH_PERSIMMON:
         case LLM_ARCH_BERT:
         case LLM_ARCH_NOMIC_BERT:
index 7ca760fa613045d9741a8a2ec5f88d02e3c98dbf..b01476a46015ac3a24424c2e3f38d5211675ed74 100755 (executable)
@@ -1,10 +1,11 @@
 #!/bin/bash
 
 wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
+unzip wikitext-2-raw-v1.zip
 
 echo "Usage:"
 echo ""
-echo "  ./perplexity -m model.gguf -f wiki.test.raw [other params]"
+echo "  ./perplexity -m model.gguf -f wikitext-2-raw/wiki.test.raw [other params]"
 echo ""
 
 exit 0