llama : add qwen2moe (llama/6074)
author Shijie <redacted>
Tue, 16 Apr 2024 15:40:48 +0000 (23:40 +0800)
committer Georgi Gerganov <redacted>
Mon, 13 May 2024 08:02:26 +0000 (11:02 +0300)
* support qwen2moe

* fix-review

* metal : support unary ops for nelements % 4 != 0

* metal : require contiguousness for float4 unary kernels

* metal : require contiguousness for float4 unary kernels (cont)

* fix-review

* names : for brevity "SHARED_EXP" -> "SHEXP"

* llama : reuse build_moe_ffn()

* llama : add model type name

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml-metal.m
ggml-metal.metal

index b43dfc3931d73503214e01f7fb2740aba119a74d..0ec47febbd20c4909444485e7bf835970bc41515 100644 (file)
@@ -42,8 +42,11 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_RELU,
     GGML_METAL_KERNEL_TYPE_SIGMOID,
     GGML_METAL_KERNEL_TYPE_GELU,
+    GGML_METAL_KERNEL_TYPE_GELU_4,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+    GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
+    GGML_METAL_KERNEL_TYPE_SILU_4,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
@@ -475,8 +478,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                      relu,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID,                   sigmoid,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                      gelu,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4,                    gelu_4,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,                gelu_quick,             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,              gelu_quick_4,           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                      silu,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                    silu_4,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX,                  soft_max,               ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,                soft_max_4,             ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,             diag_mask_inf,          true);
@@ -1181,6 +1187,9 @@ static enum ggml_status ggml_metal_graph_compute(
                 } break;
                 case GGML_OP_UNARY:
+                    // we are not taking into account the strides, so for now require contiguous tensors
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+
                     switch (ggml_get_unary_op(gf->nodes[i])) {
                         case GGML_UNARY_OP_TANH:
                             {
                                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
@@ -1219,42 +1228,60 @@ static enum ggml_status ggml_metal_graph_compute(
                             } break;
                         case GGML_UNARY_OP_GELU:
                             {
-                                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+                                int64_t n = ggml_nelements(dst);
+
+                                id<MTLComputePipelineState> pipeline = nil;
+
+                                if (n % 4 == 0) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
+                                    n /= 4;
+                                } else {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+                                }
 
                                 [encoder setComputePipelineState:pipeline];
                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                const int64_t n = ggml_nelements(dst);
-                                GGML_ASSERT(n % 4 == 0);
-
-                                [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                             } break;
                         case GGML_UNARY_OP_GELU_QUICK:
                             {
-                                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+                                int64_t n = ggml_nelements(dst);
+
+                                id<MTLComputePipelineState> pipeline = nil;
+
+                                if (n % 4 == 0) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
+                                    n /= 4;
+                                } else {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+                                }
 
                                 [encoder setComputePipelineState:pipeline];
                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                const int64_t n = ggml_nelements(dst);
-                                GGML_ASSERT(n % 4 == 0);
-
-                                [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                             } break;
                         case GGML_UNARY_OP_SILU:
                             {
-                                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+                                int64_t n = ggml_nelements(dst);
+
+                                id<MTLComputePipelineState> pipeline = nil;
+
+                                if (n % 4 == 0) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
+                                    n /= 4;
+                                } else {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+                                }
 
                                 [encoder setComputePipelineState:pipeline];
                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                const int64_t n = ggml_nelements(dst);
-                                GGML_ASSERT(n % 4 == 0);
-
-                                [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                             } break;
                         default:
                             {
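
The ggml-metal.m change above replaces the hard GGML_ASSERT(n % 4 == 0) with a runtime choice: the float4 pipeline is used only when the element count is divisible by 4 (and the dispatched thread count is divided by 4 in that case); otherwise the new scalar kernel is dispatched over all n elements. A minimal standalone C sketch of that selection pattern follows; select_pipeline, KERNEL_SCALAR and KERNEL_FLOAT4 are hypothetical names used only for illustration, not part of ggml.

/* Sketch of the "vectorize when possible" dispatch used above:
 * pick the float4 kernel only when n is divisible by 4 and shrink
 * the work count accordingly; otherwise fall back to the scalar kernel. */
#include <stdint.h>
#include <stdio.h>

enum kernel_id { KERNEL_SCALAR, KERNEL_FLOAT4 };

static enum kernel_id select_pipeline(int64_t * n) {
    if (*n % 4 == 0) {
        *n /= 4;              /* each float4 thread processes 4 elements */
        return KERNEL_FLOAT4;
    }
    return KERNEL_SCALAR;     /* odd sizes run the per-element kernel */
}

int main(void) {
    int64_t n = 10;                          /* not divisible by 4 */
    enum kernel_id k = select_pipeline(&n);  /* -> KERNEL_SCALAR, n stays 10 */
    printf("kernel=%d threads=%lld\n", (int) k, (long long) n);

    n = 16;                                  /* divisible by 4 */
    k = select_pipeline(&n);                 /* -> KERNEL_FLOAT4, n becomes 4 */
    printf("kernel=%d threads=%lld\n", (int) k, (long long) n);
    return 0;
}
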
index 1d05087f4c1c4f223152c76afe3d84793583a613..d7ae37206f4fd94cd916f55b0f72577aefe108aa 100644 (file)
@@ -249,6 +249,15 @@ constant float GELU_QUICK_COEF = -1.702f;
 constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
 
 kernel void kernel_gelu(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_4(
     device const float4 * src0,
     device       float4 * dst,
     uint tpig[[thread_position_in_grid]]) {
@@ -262,6 +271,15 @@ kernel void kernel_gelu(
 }
 
 kernel void kernel_gelu_quick(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_gelu_quick_4(
     device const float4 * src0,
     device       float4 * dst,
     uint tpig[[thread_position_in_grid]]) {
@@ -271,6 +289,14 @@ kernel void kernel_gelu_quick(
 }
 
 kernel void kernel_silu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_silu_4(
         device const float4 * src0,
         device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
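
For reference, the three activations added above as scalar kernels can be mirrored in plain C, which is a convenient way to sanity-check the Metal output against the CPU. This is only a sketch: the helper names gelu_ref, gelu_quick_ref and silu_ref are made up here, and the value of GELU_COEF_A is the standard tanh-approximation constant, which is not shown in this hunk.

#include <math.h>
#include <stdio.h>

/* Constants matching the Metal source; GELU_COEF_A assumed from the
 * standard tanh approximation of GELU. */
static const float GELU_COEF_A     = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;

/* tanh-approximated GELU, mirroring kernel_gelu / kernel_gelu_4 */
static float gelu_ref(float x) {
    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

/* sigmoid-approximated "quick" GELU, mirroring kernel_gelu_quick */
static float gelu_quick_ref(float x) {
    return x*(1.0f/(1.0f + expf(GELU_QUICK_COEF*x)));
}

/* SiLU, i.e. x * sigmoid(x), mirroring kernel_silu */
static float silu_ref(float x) {
    return x/(1.0f + expf(-x));
}

int main(void) {
    const float x = 1.5f;
    printf("gelu=%f gelu_quick=%f silu=%f\n",
           gelu_ref(x), gelu_quick_ref(x), silu_ref(x));
    return 0;
}
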