]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : consolidate unary ops (#19490)
authorGeorgi Gerganov <redacted>
Wed, 11 Feb 2026 05:51:12 +0000 (07:51 +0200)
committerGitHub <redacted>
Wed, 11 Feb 2026 05:51:12 +0000 (07:51 +0200)
ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-device.m
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal-ops.h
ggml/src/ggml-metal/ggml-metal.metal
tests/test-backend-ops.cpp

index 4c4c3ce36c4605bf8319264112b0484a8a192cd4..949e344cc8c3f0e420df7c3f31c9894b1ac53573 100644 (file)
@@ -212,61 +212,69 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta
 }
 
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
-    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
-
     char base[256];
     char name[256];
 
-    const int64_t n = ggml_nelements(op);
+    int op_num = -1;
 
-    const char * op_str = "undefined";
     switch (op->op) {
-        case GGML_OP_SCALE:      op_str = "scale";      break;
-        case GGML_OP_FILL:       op_str = "fill";       break;
-        case GGML_OP_CLAMP:      op_str = "clamp";      break;
-        case GGML_OP_SQR:        op_str = "sqr";        break;
-        case GGML_OP_SQRT:       op_str = "sqrt";       break;
-        case GGML_OP_SIN:        op_str = "sin";        break;
-        case GGML_OP_COS:        op_str = "cos";        break;
-        case GGML_OP_LOG:        op_str = "log";        break;
-        case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break;
+        case GGML_OP_SCALE:      op_num = OP_UNARY_NUM_SCALE;      break;
+        case GGML_OP_FILL:       op_num = OP_UNARY_NUM_FILL;       break;
+        case GGML_OP_CLAMP:      op_num = OP_UNARY_NUM_CLAMP;      break;
+        case GGML_OP_SQR:        op_num = OP_UNARY_NUM_SQR;        break;
+        case GGML_OP_SQRT:       op_num = OP_UNARY_NUM_SQRT;       break;
+        case GGML_OP_SIN:        op_num = OP_UNARY_NUM_SIN;        break;
+        case GGML_OP_COS:        op_num = OP_UNARY_NUM_COS;        break;
+        case GGML_OP_LOG:        op_num = OP_UNARY_NUM_LOG;        break;
+        case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_TANH:        op_str = "tanh";        break;
-                case GGML_UNARY_OP_RELU:        op_str = "relu";        break;
-                case GGML_UNARY_OP_SIGMOID:     op_str = "sigmoid";     break;
-                case GGML_UNARY_OP_GELU:        op_str = "gelu";        break;
-                case GGML_UNARY_OP_GELU_ERF:    op_str = "gelu_erf";    break;
-                case GGML_UNARY_OP_GELU_QUICK:  op_str = "gelu_quick";  break;
-                case GGML_UNARY_OP_SILU:        op_str = "silu";        break;
-                case GGML_UNARY_OP_ELU:         op_str = "elu";         break;
-                case GGML_UNARY_OP_NEG:         op_str = "neg";         break;
-                case GGML_UNARY_OP_ABS:         op_str = "abs";         break;
-                case GGML_UNARY_OP_SGN:         op_str = "sgn";         break;
-                case GGML_UNARY_OP_STEP:        op_str = "step";        break;
-                case GGML_UNARY_OP_HARDSWISH:   op_str = "hardswish";   break;
-                case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
-                case GGML_UNARY_OP_EXP:         op_str = "exp";         break;
-                case GGML_UNARY_OP_SOFTPLUS:    op_str = "softplus";    break;
-                case GGML_UNARY_OP_EXPM1:       op_str = "expm1";       break;
+                case GGML_UNARY_OP_TANH:        op_num = OP_UNARY_NUM_TANH;        break;
+                case GGML_UNARY_OP_RELU:        op_num = OP_UNARY_NUM_RELU;        break;
+                case GGML_UNARY_OP_SIGMOID:     op_num = OP_UNARY_NUM_SIGMOID;     break;
+                case GGML_UNARY_OP_GELU:        op_num = OP_UNARY_NUM_GELU;        break;
+                case GGML_UNARY_OP_GELU_ERF:    op_num = OP_UNARY_NUM_GELU_ERF;    break;
+                case GGML_UNARY_OP_GELU_QUICK:  op_num = OP_UNARY_NUM_GELU_QUICK;  break;
+                case GGML_UNARY_OP_SILU:        op_num = OP_UNARY_NUM_SILU;        break;
+                case GGML_UNARY_OP_ELU:         op_num = OP_UNARY_NUM_ELU;         break;
+                case GGML_UNARY_OP_NEG:         op_num = OP_UNARY_NUM_NEG;         break;
+                case GGML_UNARY_OP_ABS:         op_num = OP_UNARY_NUM_ABS;         break;
+                case GGML_UNARY_OP_SGN:         op_num = OP_UNARY_NUM_SGN;         break;
+                case GGML_UNARY_OP_STEP:        op_num = OP_UNARY_NUM_STEP;        break;
+                case GGML_UNARY_OP_HARDSWISH:   op_num = OP_UNARY_NUM_HARDSWISH;   break;
+                case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
+                case GGML_UNARY_OP_EXP:         op_num = OP_UNARY_NUM_EXP;         break;
+                case GGML_UNARY_OP_SOFTPLUS:    op_num = OP_UNARY_NUM_SOFTPLUS;    break;
+                case GGML_UNARY_OP_EXPM1:       op_num = OP_UNARY_NUM_EXPM1;       break;
                 default: GGML_ABORT("fatal error");
             } break;
         default: GGML_ABORT("fatal error");
     };
 
-    const char * suffix = "";
-    if (n % 4 == 0) {
-        suffix = "_4";
-    }
+    const char * t0_str = ggml_type_name(op->src[0]->type);
+    const char * t_str  = ggml_type_name(op->type);
 
-    snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
-    snprintf(name, 256, "%s", base);
+    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
+    const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;
+
+    snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
+    snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
     if (!res.pipeline) {
-        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);
+        ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
     }
 
+    res.c4  = is_c4;
+    res.cnt = is_cnt;
+
     return res;
 }
 
index 891d70c85a452e0092dde44518ec4b0d48b01b4c..50a2a3e7f72afa7f514d34c4d3088043b07274a6 100644 (file)
@@ -1011,6 +1011,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
     }
 
     switch (op->op) {
+        case GGML_OP_SCALE:
+        case GGML_OP_FILL:
+        case GGML_OP_CLAMP:
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
+        case GGML_OP_LOG:
+            return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_TANH:
@@ -1030,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 case GGML_UNARY_OP_EXP:
                 case GGML_UNARY_OP_SOFTPLUS:
                 case GGML_UNARY_OP_EXPM1:
-                    return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+                    return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                 default:
                     return false;
             }
@@ -1061,8 +1070,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
             return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ACC:
         case GGML_OP_REPEAT:
-        case GGML_OP_SCALE:
-        case GGML_OP_FILL:
         case GGML_OP_CONV_TRANSPOSE_1D:
             return true;
         case GGML_OP_CONV_TRANSPOSE_2D:
@@ -1070,14 +1077,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
                 op->src[1]->type == GGML_TYPE_F32 &&
                 op->type == GGML_TYPE_F32;
-        case GGML_OP_CLAMP:
-            return op->src[0]->type == GGML_TYPE_F32;
-        case GGML_OP_SQR:
-        case GGML_OP_SQRT:
-        case GGML_OP_SIN:
-        case GGML_OP_COS:
-        case GGML_OP_LOG:
-            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SUM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_TRI:
index 77bb403c15d2fce0579265b06820be613272cab7..44141f8e3d952bc030d1e30840f5e62e4541cfa7 100644 (file)
@@ -80,7 +80,8 @@
 #define FC_SSM_CONV                    900
 #define FC_SOLVE_TRI                   1000
 #define FC_COUNT_EQUAL                 1100
-#define FC_BIN                         1200
+#define FC_UNARY                       1200
+#define FC_BIN                         1300
 
 // op-specific constants
 #define OP_FLASH_ATTN_EXT_NQPSG 8
 #define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
 #define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
 
+#define OP_UNARY_NUM_SCALE      10
+#define OP_UNARY_NUM_FILL       11
+#define OP_UNARY_NUM_CLAMP      12
+#define OP_UNARY_NUM_SQR        13
+#define OP_UNARY_NUM_SQRT       14
+#define OP_UNARY_NUM_SIN        15
+#define OP_UNARY_NUM_COS        16
+#define OP_UNARY_NUM_LOG        17
+#define OP_UNARY_NUM_LEAKY_RELU 18
+
+#define OP_UNARY_NUM_TANH        100
+#define OP_UNARY_NUM_RELU        101
+#define OP_UNARY_NUM_SIGMOID     102
+#define OP_UNARY_NUM_GELU        103
+#define OP_UNARY_NUM_GELU_ERF    104
+#define OP_UNARY_NUM_GELU_QUICK  105
+#define OP_UNARY_NUM_SILU        106
+#define OP_UNARY_NUM_ELU         107
+#define OP_UNARY_NUM_NEG         108
+#define OP_UNARY_NUM_ABS         109
+#define OP_UNARY_NUM_SGN         110
+#define OP_UNARY_NUM_STEP        111
+#define OP_UNARY_NUM_HARDSWISH   112
+#define OP_UNARY_NUM_HARDSIGMOID 113
+#define OP_UNARY_NUM_EXP         114
+#define OP_UNARY_NUM_SOFTPLUS    115
+#define OP_UNARY_NUM_EXPM1       116
+
+
 // kernel argument structs
 //
 // - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -124,6 +154,31 @@ typedef struct {
     int32_t  dim;
 } ggml_metal_kargs_concat;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    float    slope;
+    float    scale;
+    float    bias;
+    float    val;
+    float    min;
+    float    max;
+} ggml_metal_kargs_unary;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
@@ -181,20 +236,6 @@ typedef struct {
     uint64_t nb3;
 } ggml_metal_kargs_repeat;
 
-typedef struct {
-    float scale;
-    float bias;
-} ggml_metal_kargs_scale;
-
-typedef struct {
-    float val;
-} ggml_metal_kargs_fill;
-
-typedef struct {
-    float min;
-    float max;
-} ggml_metal_kargs_clamp;
-
 typedef struct {
     int64_t  nk0;
     int64_t  ne00;
@@ -881,10 +922,6 @@ typedef struct {
     int      max_period;
 } ggml_metal_kargs_timestep_embedding;
 
-typedef struct {
-    float    slope;
-} ggml_metal_kargs_leaky_relu;
-
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
index dbf25433c259922d4530426b6097f8e335628eb3..b159a8e7fd017f12e457ddb90c193ea45ea14acf 100644 (file)
@@ -287,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
                 n_fuse = ggml_metal_op_acc(ctx, idx);
             } break;
         case GGML_OP_SCALE:
-            {
-                n_fuse = ggml_metal_op_scale(ctx, idx);
-            } break;
         case GGML_OP_FILL:
-            {
-                n_fuse = ggml_metal_op_fill(ctx, idx);
-            } break;
         case GGML_OP_CLAMP:
-            {
-                n_fuse = ggml_metal_op_clamp(ctx, idx);
-            } break;
+        case GGML_OP_LEAKY_RELU:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_SIN:
@@ -426,10 +418,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_top_k(ctx, idx);
             } break;
-        case GGML_OP_LEAKY_RELU:
-            {
-                n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
-            } break;
         case GGML_OP_TRI:
             {
                 n_fuse = ggml_metal_op_tri(ctx, idx);
@@ -722,7 +710,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
-int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
+int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
     ggml_metal_library_t lib = ctx->lib;
@@ -733,133 +721,80 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
-    float scale;
-    float bias;
-    memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
-    memcpy(&bias,  ((const int32_t *) op->op_params) + 1, sizeof(float));
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
 
-    ggml_metal_kargs_scale args = {
-        /*.scale =*/ scale,
-        /*.bias  =*/ bias,
-    };
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
 
-    int64_t n = ggml_nelements(op);
+    ggml_metal_kargs_unary args = {
+        /*.ne00  =*/ ne00,
+        /*.ne01  =*/ ne01,
+        /*.ne02  =*/ ne02,
+        /*.ne03  =*/ ne03,
+        /*.nb00  =*/ nb00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne0   =*/ ne0,
+        /*.ne1   =*/ ne1,
+        /*.ne2   =*/ ne2,
+        /*.ne3   =*/ ne3,
+        /*.nb0   =*/ nb0,
+        /*.nb1   =*/ nb1,
+        /*.nb2   =*/ nb2,
+        /*.nb3   =*/ nb3,
+        /*.slope =*/ 0.0,
+        /*.scale =*/ 0.0,
+        /*.bias  =*/ 0.0,
+        /*.val   =*/ 0.0,
+        /*.min   =*/ 0.0,
+        /*.max   =*/ 0.0,
+    };
 
-    if (n % 4 == 0) {
-        n /= 4;
+    if (op->op == GGML_OP_LEAKY_RELU) {
+        args.slope = ggml_get_op_params_f32(op, 0);
     }
 
-    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
-    return 1;
-}
-
-int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
-    ggml_tensor * op = ctx->node(idx);
-
-    ggml_metal_library_t lib = ctx->lib;
-    ggml_metal_encoder_t enc = ctx->enc;
-
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
-
-    const float val = ggml_get_op_params_f32(op, 0);
-
-    ggml_metal_kargs_fill args = {
-        /*.val =*/ val
-    };
+    if (op->op == GGML_OP_SCALE) {
+        args.scale = ggml_get_op_params_f32(op, 0);
+        args.bias  = ggml_get_op_params_f32(op, 1);
+    }
 
-    int64_t n = ggml_nelements(op);
+    if (op->op == GGML_OP_FILL) {
+        args.val = ggml_get_op_params_f32(op, 0);
+    }
 
-    if (n % 4 == 0) {
-        n /= 4;
+    if (op->op == GGML_OP_CLAMP) {
+        args.min = ggml_get_op_params_f32(op, 0);
+        args.max = ggml_get_op_params_f32(op, 1);
     }
 
     auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
 
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
-    return 1;
-}
-
-int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
-    ggml_tensor * op = ctx->node(idx);
-
-    ggml_metal_library_t lib = ctx->lib;
-    ggml_metal_encoder_t enc = ctx->enc;
-
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
-
-    float min;
-    float max;
-    memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
-    memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
-
-    ggml_metal_kargs_clamp args = {
-        /*.min =*/ min,
-        /*.max =*/ max,
-    };
-
-    int64_t n = ggml_nelements(op);
-
-    if (n % 4 == 0) {
-        n /= 4;
+    if (pipeline.c4) {
+        args.ne00 = ne00/4;
+        args.ne0  = ne0/4;
     }
 
-    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+    if (pipeline.cnt) {
+        const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
 
-    return 1;
-}
+        ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+    } else {
+        const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 
-int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
-    ggml_tensor * op = ctx->node(idx);
+        const int nth = MIN(args.ne00, nth_max);
 
-    ggml_metal_library_t lib = ctx->lib;
-    ggml_metal_encoder_t enc = ctx->enc;
+        const int nk0 = (args.ne00 + nth - 1)/nth;
 
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
-
-    int64_t n = ggml_nelements(op);
-
-    if (n % 4 == 0) {
-        n /= 4;
+        ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
     }
 
-    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         1);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
     return 1;
 }
 
@@ -4084,42 +4019,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
-int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
-    ggml_tensor * op = ctx->node(idx);
-
-    ggml_metal_library_t lib = ctx->lib;
-    ggml_metal_encoder_t enc = ctx->enc;
-
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
-
-    float slope;
-    memcpy(&slope, op->op_params, sizeof(float));
-
-    ggml_metal_kargs_leaky_relu args = {
-        /*.slope =*/ slope
-    };
-
-    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
-    int64_t n = ggml_nelements(op);
-
-    if (n % 4 == 0) {
-        n /= 4;
-    }
-
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
-    return 1;
-}
-
 int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
index 3c64e4f6007990b97bcab17368f3104cfe0c911d..29456d70d5ef16205c104066258b920811ae0ae4 100644 (file)
@@ -46,9 +46,6 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
 int ggml_metal_op_concat            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_repeat            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_acc               (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_scale             (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_fill              (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_clamp             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_unary             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_glu               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_sum               (ggml_metal_op_t ctx, int idx);
@@ -86,7 +83,6 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argmax            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argsort           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_top_k             (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_leaky_relu        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_tri               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_adamw    (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_sgd      (ggml_metal_op_t ctx, int idx);
index 35cc3bbdfdf1dfabff61affed44d2424301d55c9..7d841341a18638bdf4377d1725f09f15821c44b0 100644 (file)
@@ -895,6 +895,192 @@ enum ggml_sort_order {
     GGML_SORT_ORDER_DESC,
 };
 
+constant float GELU_COEF_A     = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+constant float SQRT_2_INV      = 0.70710678118654752440084436210484f;
+
+// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
+// ref: https://www.johndcook.com/blog/python_erf/
+constant float p_erf  = 0.3275911f;
+constant float a1_erf = 0.254829592f;
+constant float a2_erf = -0.284496736f;
+constant float a3_erf = 1.421413741f;
+constant float a4_erf = -1.453152027f;
+constant float a5_erf = 1.061405429f;
+
+template<typename T>
+T erf_approx(T x) {
+    T sign_x = sign(x);
+    x = fabs(x);
+    T t = 1.0f / (1.0f + p_erf * x);
+    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
+    return sign_x * y;
+}
+
+constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
+constant bool  FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
+
+template <typename T0, typename T>
+kernel void kernel_unary_impl(
+        constant ggml_metal_kargs_unary & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+#define FC_OP  FC_unary_op
+#define FC_CNT FC_unary_cnt
+
+    device const T0 * src0_ptr;
+    device       T  * dst_ptr;
+
+    int i0;
+
+    if (FC_CNT) {
+        i0 = tgpig.x;
+
+        src0_ptr = (device const T0 *) (src0);
+        dst_ptr  = (device       T  *) (dst);
+    } else {
+        const int i03 = tgpig.z;
+        const int i02 = tgpig.y;
+        const int k0  = tgpig.x/args.ne01;
+        const int i01 = tgpig.x - k0*args.ne01;
+
+        i0 = k0*ntg.x + tpitg.x;
+
+        src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+        dst_ptr  = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1 );
+    }
+
+    {
+        //threadgroup_barrier(mem_flags::mem_none);
+
+        if (!FC_CNT) {
+            if (i0 >= args.ne0) {
+                return;
+            }
+        }
+
+        device const T0 & x = src0_ptr[i0];
+
+        if (FC_OP == OP_UNARY_NUM_SCALE) {
+            dst_ptr[i0] = args.scale * x + args.bias;
+        }
+
+        if (FC_OP == OP_UNARY_NUM_FILL) {
+            dst_ptr[i0] = args.val;
+        }
+
+        if (FC_OP == OP_UNARY_NUM_CLAMP) {
+            dst_ptr[i0] = clamp(x, args.min, args.max);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SQR) {
+            dst_ptr[i0] = x * x;
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SQRT) {
+            dst_ptr[i0] = sqrt(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SIN) {
+            dst_ptr[i0] = sin(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_COS) {
+            dst_ptr[i0] = cos(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_LOG) {
+            dst_ptr[i0] = log(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
+            dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(x * args.slope);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_TANH) {
+            dst_ptr[i0] = precise::tanh(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_RELU) {
+            dst_ptr[i0] = fmax(0.0f, x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SIGMOID) {
+            dst_ptr[i0] = 1.0f / (1.0f + exp(-x));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_GELU) {
+            dst_ptr[i0] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
+            dst_ptr[i0] = 0.5f*x*(1.0f + erf_approx(SQRT_2_INV*x));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
+            dst_ptr[i0] = x * (1.0f/(1.0f + exp(GELU_QUICK_COEF*x)));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SILU) {
+            dst_ptr[i0] = x / (1.0f + exp(-x));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_ELU) {
+            dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(exp(x) - 1.0f);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_NEG) {
+            dst_ptr[i0] = -x;
+        }
+
+        if (FC_OP == OP_UNARY_NUM_ABS) {
+            dst_ptr[i0] = fabs(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SGN) {
+            dst_ptr[i0] = T(x > 0.0f) - T(x < 0.0f);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_STEP) {
+            dst_ptr[i0] = T(x > 0.0f);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
+            dst_ptr[i0] = x * fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
+            dst_ptr[i0] = fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_EXP) {
+            dst_ptr[i0] = exp(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
+            dst_ptr[i0] = select(log(1.0f + exp(x)), x, x > 20.0f);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_EXPM1) {
+            // TODO: precise implementation
+            dst_ptr[i0] = exp(x) - 1.0f;
+        }
+    }
+
+#undef FC_OP
+#undef FC_CNT
+}
+
+typedef decltype(kernel_unary_impl<float, float>) kernel_unary_t;
+
+template [[host_name("kernel_unary_f32_f32")]]   kernel kernel_unary_t kernel_unary_impl<float,  float>;
+template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4>;
+
+
 // OP: 0 - add, 1 - sub, 2 - mul, 3 - div
 constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
 constant short FC_bin_f  [[function_constant(FC_BIN + 1)]];
@@ -1114,414 +1300,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat
 template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
 template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
 
-kernel void kernel_scale_f32(
-        constant ggml_metal_kargs_scale & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * args.scale + args.bias;
-}
-
-kernel void kernel_scale_f32_4(
-        constant ggml_metal_kargs_scale & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * args.scale + args.bias;
-}
-
-kernel void kernel_fill_f32(
-        constant ggml_metal_kargs_fill & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = args.val;
-}
-
-kernel void kernel_fill_f32_4(
-        constant ggml_metal_kargs_fill & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = args.val;
-}
-
-kernel void kernel_clamp_f32(
-        constant ggml_metal_kargs_clamp & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = clamp(src0[tpig], args.min, args.max);
-}
-
-kernel void kernel_clamp_f32_4(
-        constant ggml_metal_kargs_clamp & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = clamp(src0[tpig], args.min, args.max);
-}
-
-kernel void kernel_relu_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = max(0.0f, src0[tpig]);
-}
-
-kernel void kernel_relu_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = max(0.0f, src0[tpig]);
-}
-
-kernel void kernel_sigmoid_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
-}
-
-kernel void kernel_sigmoid_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
-}
-
-kernel void kernel_tanh_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = precise::tanh(src0[tpig]);
-}
-
-kernel void kernel_tanh_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = precise::tanh(src0[tpig]);
-}
-
-constant float GELU_COEF_A     = 0.044715f;
-constant float GELU_QUICK_COEF = -1.702f;
-constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
-constant float SQRT_2_INV      = 0.70710678118654752440084436210484f;
-
-kernel void kernel_gelu_f32(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-kernel void kernel_gelu_f32_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    // BEWARE !!!
-    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
-    // This was observed with Falcon 7B and 40B models
-    //
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-kernel void kernel_gelu_quick_f32(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-
-    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
-
-kernel void kernel_gelu_quick_f32_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
-
-// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
-// ref: https://www.johndcook.com/blog/python_erf/
-constant float p_erf  = 0.3275911f;
-constant float a1_erf = 0.254829592f;
-constant float a2_erf = -0.284496736f;
-constant float a3_erf = 1.421413741f;
-constant float a4_erf = -1.453152027f;
-constant float a5_erf = 1.061405429f;
-
-template<typename T>
-T erf_approx(T x) {
-    T sign_x = sign(x);
-    x = fabs(x);
-    T t = 1.0f / (1.0f + p_erf * x);
-    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
-    return sign_x * y;
-}
-
-kernel void kernel_gelu_erf_f32(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-
-    dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
-}
-
-kernel void kernel_gelu_erf_f32_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
-}
-
-kernel void kernel_silu_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
-}
-
-kernel void kernel_silu_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
-}
-
-kernel void kernel_elu_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
-}
-
-kernel void kernel_elu_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
-    dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
-    dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
-    dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
-}
-
-kernel void kernel_sqr_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src0[tpig];
-}
-
-kernel void kernel_sqr_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src0[tpig];
-}
-
-kernel void kernel_sqrt_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sqrt(src0[tpig]);
-}
-
-kernel void kernel_sqrt_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sqrt(src0[tpig]);
-}
-
-kernel void kernel_sin_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sin(src0[tpig]);
-}
-
-kernel void kernel_sin_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sin(src0[tpig]);
-}
-
-kernel void kernel_cos_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = cos(src0[tpig]);
-}
-
-kernel void kernel_cos_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = cos(src0[tpig]);
-}
-
-kernel void kernel_log_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = log(src0[tpig]);
-}
-
-kernel void kernel_log_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = log(src0[tpig]);
-}
-
-kernel void kernel_neg_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = -src0[tpig];
-}
-
-kernel void kernel_neg_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = -src0[tpig];
-}
-
-kernel void kernel_abs_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = fabs(src0[tpig]);
-}
-
-kernel void kernel_abs_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = fabs(src0[tpig]);
-}
-
-kernel void kernel_sgn_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sign(src0[tpig]);
-}
-
-kernel void kernel_sgn_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sign(src0[tpig]);
-}
-
-kernel void kernel_step_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = step(0.0f, src0[tpig]);
-}
-
-kernel void kernel_step_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = step(0.0f, src0[tpig]);
-}
-
-kernel void kernel_hardswish_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_hardswish_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_hardsigmoid_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_hardsigmoid_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_exp_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]);
-}
-
-kernel void kernel_exp_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]);
-}
-
-kernel void kernel_softplus_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-    dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
-}
-
-kernel void kernel_softplus_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-    dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
-}
-
-kernel void kernel_expm1_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]) - 1.0f;
-}
-
-kernel void kernel_expm1_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]) - 1.0f;
-}
-
 kernel void kernel_reglu_f32(
         constant ggml_metal_kargs_glu & args,
         device const char * src0,
@@ -5072,24 +4850,6 @@ kernel void kernel_argsort_merge_f32_i32(
 template [[host_name("kernel_argsort_merge_f32_i32_asc")]]  kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
 template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
 
-kernel void kernel_leaky_relu_f32(
-        constant     ggml_metal_kargs_leaky_relu & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = x > 0.0f ? x : x * args.slope;
-}
-
-kernel void kernel_leaky_relu_f32_4(
-        constant     ggml_metal_kargs_leaky_relu & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
-}
-
 constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
 
 constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
@@ -9939,7 +9699,7 @@ kernel void kernel_opt_step_sgd_f32(
 
 template<typename T>
 kernel void kernel_memset(
-        constant ggml_metal_kargs_fill & args,
+        constant ggml_metal_kargs_memset & args,
         device T * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = args.val;
index 3bdb3d7ba9dac289478fdae359c875ec2ffb4f10..1f40d2fc38688c0ad7f16952aa6dd55cce9e92e1 100644 (file)
@@ -7882,20 +7882,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_round     (type));
         test_cases.emplace_back(new test_trunc     (type));
         test_cases.emplace_back(new test_sqr       (type, {7, 1, 5, 3}));
+        test_cases.emplace_back(new test_sqr       (type, {1024, 1024, 1, 1}));
         test_cases.emplace_back(new test_sqrt      (type, {7, 1, 5, 3}));
+        test_cases.emplace_back(new test_sqrt      (type, {1024, 1024, 1, 1}));
         test_cases.emplace_back(new test_log       (type, {7, 1, 5, 3}));
+        test_cases.emplace_back(new test_log       (type, {1024, 1024, 1, 1}));
         test_cases.emplace_back(new test_sin       (type, {7, 1, 5, 3}));
+        test_cases.emplace_back(new test_sin       (type, {1024, 1024, 1, 1}));
         test_cases.emplace_back(new test_cos       (type, {7, 1, 5, 3}));
+        test_cases.emplace_back(new test_cos       (type, {1024, 1024, 1, 1}));
         test_cases.emplace_back(new test_clamp     (type, {7, 1, 5, 3}));
+        test_cases.emplace_back(new test_clamp     (type, {1024, 1024, 1, 1}));
         test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3}));
+        test_cases.emplace_back(new test_leaky_relu(type, {1024, 1024, 1, 1}));
         test_cases.emplace_back(new test_floor     (type, {7, 1, 5, 3}));
-        test_cases.emplace_back(new test_floor     (type, { 1024, 1024, 1, 1 }));
+        test_cases.emplace_back(new test_floor     (type, {1024, 1024, 1, 1}));
         test_cases.emplace_back(new test_ceil      (type, {7, 1, 5, 3}));
-        test_cases.emplace_back(new test_ceil      (type, { 1024, 1024, 1, 1 }));
+        test_cases.emplace_back(new test_ceil      (type, {1024, 1024, 1, 1}));
         test_cases.emplace_back(new test_round     (type, {7, 1, 5, 3}));
-        test_cases.emplace_back(new test_round     (type, { 1024, 1024, 1, 1 }));
+        test_cases.emplace_back(new test_round     (type, {1024, 1024, 1, 1}));
         test_cases.emplace_back(new test_trunc     (type, {7, 1, 5, 3}));
-        test_cases.emplace_back(new test_trunc     (type, { 1024, 1024, 1, 1 }));
+        test_cases.emplace_back(new test_trunc     (type, {1024, 1024, 1, 1}));
     }
 
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));