]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Added support for GGML_OP_CLAMP in Metal (llama/6662)
authorDave <redacted>
Sun, 14 Apr 2024 11:14:19 +0000 (07:14 -0400)
committerGeorgi Gerganov <redacted>
Mon, 13 May 2024 08:02:26 +0000 (11:02 +0300)
* Added support for GGML_OP_CLAMP in Metal

* Corrected size

---------

Co-authored-by: dave-fl <redacted>
ggml-metal.m
ggml-metal.metal

index 7f0f1f1f1ce96e1c331ddf88dfbde3d717b176d1..b43dfc3931d73503214e01f7fb2740aba119a74d 100644 (file)
@@ -37,6 +37,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_DIV_ROW,
     GGML_METAL_KERNEL_TYPE_SCALE,
     GGML_METAL_KERNEL_TYPE_SCALE_4,
+    GGML_METAL_KERNEL_TYPE_CLAMP,
     GGML_METAL_KERNEL_TYPE_TANH,
     GGML_METAL_KERNEL_TYPE_RELU,
     GGML_METAL_KERNEL_TYPE_SIGMOID,
@@ -469,6 +470,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                   div_row,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                     scale,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                   scale_4,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,                     clamp,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                      tanh,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                      relu,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID,                   sigmoid,                true);
@@ -716,6 +718,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_SCALE:
+        case GGML_OP_CLAMP:
         case GGML_OP_SQR:
         case GGML_OP_SUM_ROWS:
             return true;
@@ -1157,6 +1160,25 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
+                case GGML_OP_CLAMP:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
+
+                    float min;
+                    float max;
+                    memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
+                    memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0   offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&min length:sizeof(min) atIndex:2];
+                    [encoder setBytes:&max length:sizeof(max) atIndex:3];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
                 case GGML_OP_UNARY:
                     switch (ggml_get_unary_op(gf->nodes[i])) {
                         case GGML_UNARY_OP_TANH:
index 79cce21ff4f26b4ec5911dafe2eae4ab676f9c11..1d05087f4c1c4f223152c76afe3d84793583a613 100644 (file)
@@ -213,6 +213,15 @@ kernel void kernel_scale_4(
     dst[tpig] = src0[tpig] * scale;
 }
 
+kernel void kernel_clamp(
+        device const float * src0,
+        device       float * dst,
+        constant     float & min,
+        constant     float & max,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
+}
+
 kernel void kernel_relu(
         device const float * src0,
         device       float * dst,