]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add qwen2moe (#6074)
authorShijie <redacted>
Tue, 16 Apr 2024 15:40:48 +0000 (23:40 +0800)
committerGitHub <redacted>
Tue, 16 Apr 2024 15:40:48 +0000 (18:40 +0300)
* support qwen2moe

* fix-review

* metal : support unary ops for nelements % 4 != 0

* metal : require contiguousness for float4 unary kernels

* metal : require contiguousness for float4 unary kernels (cont)

* fix-review

* names : for brevity "SHARED_EXP" -> "SHEXP"

* llama : reuse build_moe_ffn()

* llama : add model type name

---------

Co-authored-by: Georgi Gerganov <redacted>
convert-hf-to-gguf.py
ggml-metal.m
ggml-metal.metal
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
llama.cpp
tests/test-backend-ops.cpp

index 6d28ab5e4919cb4f39c2449f619c55c2016af85b..a93b0666c6b864b59eaaf6779d57c774839c38b1 100755 (executable)
@@ -1700,6 +1700,105 @@ class Qwen2Model(Model):
     model_arch = gguf.MODEL_ARCH.QWEN2
 
 
+@Model.register("Qwen2MoeForCausalLM")
+class Qwen2MoeModel(Model):
+    model_arch = gguf.MODEL_ARCH.QWEN2MOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (n_experts := self.hparams.get("num_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        n_experts = self.hparams.get("num_experts")
+        experts = dict()
+        for name, data_torch in self.get_tensors():
+            # we don't need these
+            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
+                continue
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # process the experts separately
+            if name.find("experts") != -1:
+                experts[name] = data
+                if len(experts) >= n_experts * 3:
+                    # merge the experts into a single 3d tensor
+                    for bid in range(block_count):
+                        for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                            full = True
+                            for xid in range(n_experts):
+                                ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                                if ename not in experts:
+                                    full = False
+                                    break
+                            if not full:
+                                continue
+
+                            datas = []
+                            for xid in range(n_experts):
+                                ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                                datas.append(experts[ename])
+                                del experts[ename]
+
+                            data = np.stack(datas, axis=0)
+                            data_dtype = data.dtype
+
+                            if self.ftype == 0 and data_dtype == np.float16:
+                                data = data.astype(np.float32)
+
+                            if self.ftype == 1 and data_dtype == np.float32:
+                                data = data.astype(np.float16)
+
+                            merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                            new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
+                            if new_name is None:
+                                print(f"Can not map tensor {name!r}")
+                                sys.exit()
+
+                            print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
+
+                            self.gguf_writer.add_tensor(new_name, data)
+                continue
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")):
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+        if len(experts) > 0:
+            raise ValueError(f"Unprocessed experts: {experts.keys()}")
+
+
 @Model.register("GPT2LMHeadModel")
 class GPT2Model(Model):
     model_arch = gguf.MODEL_ARCH.GPT2
index 0207b787a0255abd7b53fd0618610936ff2e45de..ae6ddeacd347f80938fe692d8b271b9379eceea5 100644 (file)
@@ -41,8 +41,11 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_TANH,
     GGML_METAL_KERNEL_TYPE_RELU,
     GGML_METAL_KERNEL_TYPE_GELU,
+    GGML_METAL_KERNEL_TYPE_GELU_4,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+    GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
+    GGML_METAL_KERNEL_TYPE_SILU_4,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
@@ -473,8 +476,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                      tanh,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                      relu,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                      gelu,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4,                    gelu_4,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,                gelu_quick,             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,              gelu_quick_4,           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                      silu,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                    silu_4,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX,                  soft_max,               ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,                soft_max_4,             ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,             diag_mask_inf,          true);
@@ -1178,6 +1184,9 @@ static enum ggml_status ggml_metal_graph_compute(
                 } break;
                 case GGML_OP_UNARY:
                     switch (ggml_get_unary_op(gf->nodes[i])) {
+                        // we are not taking into account the strides, so for now require contiguous tensors
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
                         case GGML_UNARY_OP_TANH:
                             {
                                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
@@ -1204,42 +1213,60 @@ static enum ggml_status ggml_metal_graph_compute(
                             } break;
                         case GGML_UNARY_OP_GELU:
                             {
-                                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+                                int64_t n = ggml_nelements(dst);
+
+                                id<MTLComputePipelineState> pipeline = nil;
+
+                                if (n % 4 == 0) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
+                                    n /= 4;
+                                } else {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+                                }
 
                                 [encoder setComputePipelineState:pipeline];
                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                const int64_t n = ggml_nelements(dst);
-                                GGML_ASSERT(n % 4 == 0);
-
-                                [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                             } break;
                         case GGML_UNARY_OP_GELU_QUICK:
                             {
-                                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+                                int64_t n = ggml_nelements(dst);
+
+                                id<MTLComputePipelineState> pipeline = nil;
+
+                                if (n % 4 == 0) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
+                                    n /= 4;
+                                } else {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+                                }
 
                                 [encoder setComputePipelineState:pipeline];
                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                const int64_t n = ggml_nelements(dst);
-                                GGML_ASSERT(n % 4 == 0);
-
-                                [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                             } break;
                         case GGML_UNARY_OP_SILU:
                             {
-                                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+                                int64_t n = ggml_nelements(dst);
+
+                                id<MTLComputePipelineState> pipeline = nil;
+
+                                if (n % 4 == 0) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
+                                    n /= 4;
+                                } else {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+                                }
 
                                 [encoder setComputePipelineState:pipeline];
                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                const int64_t n = ggml_nelements(dst);
-                                GGML_ASSERT(n % 4 == 0);
-
-                                [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                             } break;
                         default:
                             {
index 56748166ca4caf5d3ee6636a168cd2986ffd18de..82a8cad93c8bcb5a29767ee56af5e178bcb5627b 100644 (file)
@@ -242,6 +242,15 @@ constant float GELU_QUICK_COEF = -1.702f;
 constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
 
 kernel void kernel_gelu(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_4(
     device const float4 * src0,
     device       float4 * dst,
     uint tpig[[thread_position_in_grid]]) {
@@ -255,6 +264,15 @@ kernel void kernel_gelu(
 }
 
 kernel void kernel_gelu_quick(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_gelu_quick_4(
     device const float4 * src0,
     device       float4 * dst,
     uint tpig[[thread_position_in_grid]]) {
@@ -264,6 +282,14 @@ kernel void kernel_gelu_quick(
 }
 
 kernel void kernel_silu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_silu_4(
         device const float4 * src0,
         device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
index 1358206a3e5691e175c63709a322552f521eb2a9..df861164fe2f1ed1cf362693d611121af1a0dfc1 100644 (file)
@@ -120,6 +120,7 @@ class MODEL_ARCH(IntEnum):
     STABLELM   = auto()
     QWEN       = auto()
     QWEN2      = auto()
+    QWEN2MOE   = auto()
     PHI2       = auto()
     PLAMO      = auto()
     CODESHELL  = auto()
@@ -135,41 +136,45 @@ class MODEL_ARCH(IntEnum):
 
 
 class MODEL_TENSOR(IntEnum):
-    TOKEN_EMBD      = auto()
-    TOKEN_EMBD_NORM = auto()
-    TOKEN_TYPES     = auto()
-    POS_EMBD        = auto()
-    OUTPUT          = auto()
-    OUTPUT_NORM     = auto()
-    ROPE_FREQS      = auto()
-    ATTN_Q          = auto()
-    ATTN_K          = auto()
-    ATTN_V          = auto()
-    ATTN_QKV        = auto()
-    ATTN_OUT        = auto()
-    ATTN_NORM       = auto()
-    ATTN_NORM_2     = auto()
-    ATTN_OUT_NORM   = auto()
-    ATTN_ROT_EMBD   = auto()
-    FFN_GATE_INP    = auto()
-    FFN_NORM        = auto()
-    FFN_GATE        = auto()
-    FFN_DOWN        = auto()
-    FFN_UP          = auto()
-    FFN_ACT         = auto()
-    FFN_GATE_EXP    = auto()
-    FFN_DOWN_EXP    = auto()
-    FFN_UP_EXP      = auto()
-    ATTN_Q_NORM     = auto()
-    ATTN_K_NORM     = auto()
-    LAYER_OUT_NORM  = auto()
-    SSM_IN          = auto()
-    SSM_CONV1D      = auto()
-    SSM_X           = auto()
-    SSM_DT          = auto()
-    SSM_A           = auto()
-    SSM_D           = auto()
-    SSM_OUT         = auto()
+    TOKEN_EMBD         = auto()
+    TOKEN_EMBD_NORM    = auto()
+    TOKEN_TYPES        = auto()
+    POS_EMBD           = auto()
+    OUTPUT             = auto()
+    OUTPUT_NORM        = auto()
+    ROPE_FREQS         = auto()
+    ATTN_Q             = auto()
+    ATTN_K             = auto()
+    ATTN_V             = auto()
+    ATTN_QKV           = auto()
+    ATTN_OUT           = auto()
+    ATTN_NORM          = auto()
+    ATTN_NORM_2        = auto()
+    ATTN_OUT_NORM      = auto()
+    ATTN_ROT_EMBD      = auto()
+    FFN_GATE_INP       = auto()
+    FFN_GATE_INP_SHEXP = auto()
+    FFN_NORM           = auto()
+    FFN_GATE           = auto()
+    FFN_DOWN           = auto()
+    FFN_UP             = auto()
+    FFN_ACT            = auto()
+    FFN_GATE_EXP       = auto()
+    FFN_DOWN_EXP       = auto()
+    FFN_UP_EXP         = auto()
+    FFN_GATE_SHEXP     = auto()
+    FFN_DOWN_SHEXP     = auto()
+    FFN_UP_SHEXP       = auto()
+    ATTN_Q_NORM        = auto()
+    ATTN_K_NORM        = auto()
+    LAYER_OUT_NORM     = auto()
+    SSM_IN             = auto()
+    SSM_CONV1D         = auto()
+    SSM_X              = auto()
+    SSM_DT             = auto()
+    SSM_A              = auto()
+    SSM_D              = auto()
+    SSM_OUT            = auto()
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -190,6 +195,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.STABLELM:       "stablelm",
     MODEL_ARCH.QWEN:           "qwen",
     MODEL_ARCH.QWEN2:          "qwen2",
+    MODEL_ARCH.QWEN2MOE:       "qwen2moe",
     MODEL_ARCH.PHI2:           "phi2",
     MODEL_ARCH.PLAMO:          "plamo",
     MODEL_ARCH.CODESHELL:      "codeshell",
@@ -205,41 +211,45 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
-    MODEL_TENSOR.TOKEN_EMBD:      "token_embd",
-    MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
-    MODEL_TENSOR.TOKEN_TYPES:     "token_types",
-    MODEL_TENSOR.POS_EMBD:        "position_embd",
-    MODEL_TENSOR.OUTPUT_NORM:     "output_norm",
-    MODEL_TENSOR.OUTPUT:          "output",
-    MODEL_TENSOR.ROPE_FREQS:      "rope_freqs",
-    MODEL_TENSOR.ATTN_NORM:       "blk.{bid}.attn_norm",
-    MODEL_TENSOR.ATTN_NORM_2:     "blk.{bid}.attn_norm_2",
-    MODEL_TENSOR.ATTN_QKV:        "blk.{bid}.attn_qkv",
-    MODEL_TENSOR.ATTN_Q:          "blk.{bid}.attn_q",
-    MODEL_TENSOR.ATTN_K:          "blk.{bid}.attn_k",
-    MODEL_TENSOR.ATTN_V:          "blk.{bid}.attn_v",
-    MODEL_TENSOR.ATTN_OUT:        "blk.{bid}.attn_output",
-    MODEL_TENSOR.ATTN_ROT_EMBD:   "blk.{bid}.attn_rot_embd",
-    MODEL_TENSOR.ATTN_Q_NORM:     "blk.{bid}.attn_q_norm",
-    MODEL_TENSOR.ATTN_K_NORM:     "blk.{bid}.attn_k_norm",
-    MODEL_TENSOR.ATTN_OUT_NORM:   "blk.{bid}.attn_output_norm",
-    MODEL_TENSOR.FFN_GATE_INP:    "blk.{bid}.ffn_gate_inp",
-    MODEL_TENSOR.FFN_NORM:        "blk.{bid}.ffn_norm",
-    MODEL_TENSOR.FFN_GATE:        "blk.{bid}.ffn_gate",
-    MODEL_TENSOR.FFN_DOWN:        "blk.{bid}.ffn_down",
-    MODEL_TENSOR.FFN_UP:          "blk.{bid}.ffn_up",
-    MODEL_TENSOR.FFN_ACT:         "blk.{bid}.ffn",
-    MODEL_TENSOR.FFN_GATE_EXP:    "blk.{bid}.ffn_gate_exps",
-    MODEL_TENSOR.FFN_DOWN_EXP:    "blk.{bid}.ffn_down_exps",
-    MODEL_TENSOR.FFN_UP_EXP:      "blk.{bid}.ffn_up_exps",
-    MODEL_TENSOR.LAYER_OUT_NORM:  "blk.{bid}.layer_output_norm",
-    MODEL_TENSOR.SSM_IN:          "blk.{bid}.ssm_in",
-    MODEL_TENSOR.SSM_CONV1D:      "blk.{bid}.ssm_conv1d",
-    MODEL_TENSOR.SSM_X:           "blk.{bid}.ssm_x",
-    MODEL_TENSOR.SSM_DT:          "blk.{bid}.ssm_dt",
-    MODEL_TENSOR.SSM_A:           "blk.{bid}.ssm_a",
-    MODEL_TENSOR.SSM_D:           "blk.{bid}.ssm_d",
-    MODEL_TENSOR.SSM_OUT:         "blk.{bid}.ssm_out",
+    MODEL_TENSOR.TOKEN_EMBD:         "token_embd",
+    MODEL_TENSOR.TOKEN_EMBD_NORM:    "token_embd_norm",
+    MODEL_TENSOR.TOKEN_TYPES:        "token_types",
+    MODEL_TENSOR.POS_EMBD:           "position_embd",
+    MODEL_TENSOR.OUTPUT_NORM:        "output_norm",
+    MODEL_TENSOR.OUTPUT:             "output",
+    MODEL_TENSOR.ROPE_FREQS:         "rope_freqs",
+    MODEL_TENSOR.ATTN_NORM:          "blk.{bid}.attn_norm",
+    MODEL_TENSOR.ATTN_NORM_2:        "blk.{bid}.attn_norm_2",
+    MODEL_TENSOR.ATTN_QKV:           "blk.{bid}.attn_qkv",
+    MODEL_TENSOR.ATTN_Q:             "blk.{bid}.attn_q",
+    MODEL_TENSOR.ATTN_K:             "blk.{bid}.attn_k",
+    MODEL_TENSOR.ATTN_V:             "blk.{bid}.attn_v",
+    MODEL_TENSOR.ATTN_OUT:           "blk.{bid}.attn_output",
+    MODEL_TENSOR.ATTN_ROT_EMBD:      "blk.{bid}.attn_rot_embd",
+    MODEL_TENSOR.ATTN_Q_NORM:        "blk.{bid}.attn_q_norm",
+    MODEL_TENSOR.ATTN_K_NORM:        "blk.{bid}.attn_k_norm",
+    MODEL_TENSOR.ATTN_OUT_NORM:      "blk.{bid}.attn_output_norm",
+    MODEL_TENSOR.FFN_GATE_INP:       "blk.{bid}.ffn_gate_inp",
+    MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
+    MODEL_TENSOR.FFN_NORM:           "blk.{bid}.ffn_norm",
+    MODEL_TENSOR.FFN_GATE:           "blk.{bid}.ffn_gate",
+    MODEL_TENSOR.FFN_DOWN:           "blk.{bid}.ffn_down",
+    MODEL_TENSOR.FFN_UP:             "blk.{bid}.ffn_up",
+    MODEL_TENSOR.FFN_GATE_SHEXP:     "blk.{bid}.ffn_gate_shexp",
+    MODEL_TENSOR.FFN_DOWN_SHEXP:     "blk.{bid}.ffn_down_shexp",
+    MODEL_TENSOR.FFN_UP_SHEXP:       "blk.{bid}.ffn_up_shexp",
+    MODEL_TENSOR.FFN_ACT:            "blk.{bid}.ffn",
+    MODEL_TENSOR.FFN_GATE_EXP:       "blk.{bid}.ffn_gate_exps",
+    MODEL_TENSOR.FFN_DOWN_EXP:       "blk.{bid}.ffn_down_exps",
+    MODEL_TENSOR.FFN_UP_EXP:         "blk.{bid}.ffn_up_exps",
+    MODEL_TENSOR.LAYER_OUT_NORM:     "blk.{bid}.layer_output_norm",
+    MODEL_TENSOR.SSM_IN:             "blk.{bid}.ssm_in",
+    MODEL_TENSOR.SSM_CONV1D:         "blk.{bid}.ssm_conv1d",
+    MODEL_TENSOR.SSM_X:              "blk.{bid}.ssm_x",
+    MODEL_TENSOR.SSM_DT:             "blk.{bid}.ssm_dt",
+    MODEL_TENSOR.SSM_A:              "blk.{bid}.ssm_a",
+    MODEL_TENSOR.SSM_D:              "blk.{bid}.ssm_d",
+    MODEL_TENSOR.SSM_OUT:            "blk.{bid}.ssm_out",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -474,6 +484,25 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.QWEN2MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_INP_SHEXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+    ],
     MODEL_ARCH.PLAMO: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index ec6fcbb838425633f70e76aef5865a0a385f335a..10de36fa883ee54b7ecb4d12a968695756cc91e6 100644 (file)
@@ -208,10 +208,15 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_GATE_INP: (
             "layers.{bid}.feed_forward.gate",             # mixtral
             "model.layers.{bid}.block_sparse_moe.gate",   # mixtral
+            "model.layers.{bid}.mlp.gate",                # qwen2moe
             "transformer.decoder_layer.{bid}.router",     # Grok
             "transformer.blocks.{bid}.ffn.router.layer",  # dbrx
         ),
 
+        MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
+            "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe
+        ),
+
         # Feed-forward up
         MODEL_TENSOR.FFN_UP: (
             "gpt_neox.layers.{bid}.mlp.dense_h_to_4h",                # gptneox
@@ -236,9 +241,14 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
-            "layers.{bid}.feed_forward.experts.w3",                 # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear_v",         # Grok (merged)
-            "transformer.blocks.{bid}.ffn.experts.mlp.v1",          # dbrx
+            "layers.{bid}.feed_forward.experts.w3",          # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear_v",  # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.v1",   # dbrx
+            "model.layers.{bid}.mlp.experts.up_proj",        # qwen2moe (merged)
+        ),
+
+        MODEL_TENSOR.FFN_UP_SHEXP: (
+            "model.layers.{bid}.mlp.shared_expert.up_proj",  # qwen2moe
         ),
 
         # AWQ-activation gate
@@ -260,6 +270,11 @@ class TensorNameMap:
             "layers.{bid}.feed_forward.experts.w1",         # mixtral (merged)
             "transformer.decoder_layer.{bid}.moe.linear",   # Grok (merged)
             "transformer.blocks.{bid}.ffn.experts.mlp.w1",  # dbrx
+            "model.layers.{bid}.mlp.experts.gate_proj",     # qwen2moe (merged)
+        ),
+
+        MODEL_TENSOR.FFN_GATE_SHEXP: (
+            "model.layers.{bid}.mlp.shared_expert.gate_proj",  # qwen2moe
         ),
 
         # Feed-forward down
@@ -285,9 +300,14 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
-            "layers.{bid}.feed_forward.experts.w2",                 # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear_1",         # Grok (merged)
-            "transformer.blocks.{bid}.ffn.experts.mlp.w2",          # dbrx
+            "layers.{bid}.feed_forward.experts.w2",          # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear_1",  # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.w2",   # dbrx
+            "model.layers.{bid}.mlp.experts.down_proj",      # qwen2moe (merged)
+        ),
+
+        MODEL_TENSOR.FFN_DOWN_SHEXP: (
+            "model.layers.{bid}.mlp.shared_expert.down_proj",  # qwen2moe
         ),
 
         MODEL_TENSOR.ATTN_Q_NORM: (
@@ -366,7 +386,7 @@ class TensorNameMap:
                 if tensor not in MODEL_TENSORS[arch]:
                     continue
                 # TODO: make this configurable
-                n_experts = 8
+                n_experts = 60
                 for xid in range(n_experts):
                     tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
                     self.mapping[tensor_name] = (tensor, tensor_name)
index 38e5936254e4a40c156818d28a6a827728a15b52..340e68fde3ac9441c9e5eb011784c4ffa14d63ed 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
 #endif
 
 #define LLAMA_MAX_NODES   8192
-#define LLAMA_MAX_EXPERTS 16
+#define LLAMA_MAX_EXPERTS 60
 
 
 //
@@ -209,6 +209,7 @@ enum llm_arch {
     LLM_ARCH_STABLELM,
     LLM_ARCH_QWEN,
     LLM_ARCH_QWEN2,
+    LLM_ARCH_QWEN2MOE,
     LLM_ARCH_PHI2,
     LLM_ARCH_PLAMO,
     LLM_ARCH_CODESHELL,
@@ -242,6 +243,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_STABLELM,        "stablelm"   },
     { LLM_ARCH_QWEN,            "qwen"       },
     { LLM_ARCH_QWEN2,           "qwen2"      },
+    { LLM_ARCH_QWEN2MOE,        "qwen2moe"   },
     { LLM_ARCH_PHI2,            "phi2"       },
     { LLM_ARCH_PLAMO,           "plamo"      },
     { LLM_ARCH_CODESHELL,       "codeshell"  },
@@ -437,6 +439,7 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_OUT_NORM,
     LLM_TENSOR_ATTN_ROT_EMBD,
     LLM_TENSOR_FFN_GATE_INP,
+    LLM_TENSOR_FFN_GATE_INP_SHEXP,
     LLM_TENSOR_FFN_NORM,
     LLM_TENSOR_FFN_GATE,
     LLM_TENSOR_FFN_DOWN,
@@ -448,6 +451,9 @@ enum llm_tensor {
     LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
     LLM_TENSOR_FFN_GATE_EXPS,
     LLM_TENSOR_FFN_UP_EXPS,
+    LLM_TENSOR_FFN_DOWN_SHEXP,
+    LLM_TENSOR_FFN_GATE_SHEXP,
+    LLM_TENSOR_FFN_UP_SHEXP,
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
     LLM_TENSOR_LAYER_OUT_NORM,
@@ -745,6 +751,28 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_QWEN2MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,     "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
+        },
+    },
     {
         LLM_ARCH_PHI2,
         {
@@ -1731,6 +1759,7 @@ enum e_model {
     MODEL_MEDIUM,
     MODEL_LARGE,
     MODEL_XL,
+    MODEL_A2_7B,
     MODEL_8x7B,
     MODEL_8x22B,
     MODEL_16x12B,
@@ -1917,6 +1946,12 @@ struct llama_layer {
     struct ggml_tensor * ffn_down_exps;
     struct ggml_tensor * ffn_up_exps ;
 
+    // ff shared expert (shexp)
+    struct ggml_tensor * ffn_gate_inp_shexp;
+    struct ggml_tensor * ffn_gate_shexp;
+    struct ggml_tensor * ffn_down_shexp;
+    struct ggml_tensor * ffn_up_shexp;
+
     // ff bias
     struct ggml_tensor * ffn_down_b; // b2
     struct ggml_tensor * ffn_up_b;   // b3
@@ -3587,6 +3622,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_MEDIUM: return "0.4B";
         case MODEL_LARGE:  return "0.8B";
         case MODEL_XL:     return "1.5B";
+        case MODEL_A2_7B:  return "A2.7B";
         case MODEL_8x7B:   return "8x7B";
         case MODEL_8x22B:  return "8x22B";
         case MODEL_16x12B: return "16x12B";
@@ -3886,6 +3922,14 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_QWEN2MOE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 24: model.type = e_model::MODEL_A2_7B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_PHI2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -5156,6 +5200,54 @@ static bool llm_load_tensors(
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
                     }
                 } break;
+            case LLM_ARCH_QWEN2MOE:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                        // optional bias tensors
+                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd});
+                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa});
+                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa});
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+
+                        layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
+
+                        GGML_ASSERT(hparams.n_expert      > 0);
+                        GGML_ASSERT(hparams.n_expert_used > 0);
+
+                        // MoE branch
+                        auto n_ff_exp = n_ff / hparams.n_expert_used;
+                        layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert});
+                        layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert});
+                        layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert});
+
+                        // Shared expert branch
+                        layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd});
+                        layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd,   n_ff});
+                        layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd,   n_ff});
+                    }
+                } break;
             case LLM_ARCH_PHI2:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -6532,7 +6624,7 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il);
+                cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, true, il);
             }
 
             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -6565,7 +6657,7 @@ struct llm_build_context {
     }
 
     // REVIEW: will be replaced by https://github.com/ggerganov/llama.cpp/pull/6505
-    ggml_tensor * build_moe_ffn(ggml_tensor * cur, int32_t n_tokens, llm_ffn_op_type type_op, int il) {
+    ggml_tensor * build_moe_ffn(ggml_tensor * cur, int32_t n_tokens, llm_ffn_op_type type_op, bool norm_w, int il) {
         ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
         cb(logits, "ffn_moe_logits", il);
 
@@ -6582,11 +6674,13 @@ struct llm_build_context {
 
         weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
 
-        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
-        cb(weights_sum, "ffn_moe_weights_sum", il);
+        if (norm_w) {
+            ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
+            cb(weights_sum, "ffn_moe_weights_sum", il);
 
-        weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
-        cb(weights, "ffn_moe_weights_norm", il);
+            weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+            cb(weights, "ffn_moe_weights_norm", il);
+        }
 
         // compute expert outputs
         ggml_tensor * moe_out = nullptr;
@@ -7083,7 +7177,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = build_moe_ffn(cur, n_tokens, LLM_FFN_GELU, il);
+            cur = build_moe_ffn(cur, n_tokens, LLM_FFN_GELU, true, il);
 
             // Grok
             // if layer_out_norm is present then apply it before adding the input
@@ -7219,7 +7313,7 @@ struct llm_build_context {
                                  LLM_NORM, cb, il);
             cb(cur, "attn_out_norm", il);
 
-            cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il);
+            cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, true, il);
 
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
@@ -8434,6 +8528,141 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_qwen2moe() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // MoE branch
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            ggml_tensor * moe_out = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, false, il);
+
+            // FFN shared expert
+            {
+                ggml_tensor * cur_gate_inp = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
+                cb(cur_gate_inp, "ffn_shexp_gate_inp", il);
+
+                // sigmoid
+                ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp);
+                cb(cur_gate, "ffn_shexp_gate", il);
+
+                ggml_tensor * cur_ffn = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up_shexp,   NULL,
+                        model.layers[il].ffn_gate_shexp, NULL,
+                        model.layers[il].ffn_down_shexp, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur_ffn, "ffn_shexp", il);
+
+                ggml_tensor * ffn_shexp_out = ggml_mul(ctx0, cur_ffn, cur_gate);
+                cb(ffn_shexp_out, "ffn_shexp_out", il);
+
+                moe_out = ggml_add(ctx0, moe_out, ffn_shexp_out);
+                cb(moe_out, "ffn_out", il);
+
+                cur = moe_out;
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_phi2() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -9917,6 +10146,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_qwen2();
             } break;
+        case LLM_ARCH_QWEN2MOE:
+            {
+                result = llm.build_qwen2moe();
+            } break;
         case LLM_ARCH_PHI2:
             {
                 result = llm.build_phi2();
@@ -14834,6 +15067,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_STABLELM:
         case LLM_ARCH_QWEN:
         case LLM_ARCH_QWEN2:
+        case LLM_ARCH_QWEN2MOE:
         case LLM_ARCH_PHI2:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_STARCODER2:
index b50675952b51dd47a663eb4ff7cdcf3c5b28e6b3..21adba42e31481bce047d22d49abffdc02dcaa58 100644 (file)
@@ -1878,6 +1878,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     // unary ops
     for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
         test_cases.emplace_back(new test_unary((ggml_unary_op) op));
+        test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 7, 13, 19, 23 }));
     }
 
     test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));