}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
- GGML_ASSERT(ggml_is_contiguous(op->src[0]));
-
char base[256];
char name[256];
- const int64_t n = ggml_nelements(op);
+ int op_num = -1; // OP_UNARY_NUM_* id of the requested op, assigned by the switch below
- const char * op_str = "undefined";
switch (op->op) {
- case GGML_OP_SCALE: op_str = "scale"; break;
- case GGML_OP_FILL: op_str = "fill"; break;
- case GGML_OP_CLAMP: op_str = "clamp"; break;
- case GGML_OP_SQR: op_str = "sqr"; break;
- case GGML_OP_SQRT: op_str = "sqrt"; break;
- case GGML_OP_SIN: op_str = "sin"; break;
- case GGML_OP_COS: op_str = "cos"; break;
- case GGML_OP_LOG: op_str = "log"; break;
- case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break;
+ case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break;
+ case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break;
+ case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break;
+ case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break;
+ case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break;
+ case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break;
+ case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break;
+ case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break;
+ case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
- case GGML_UNARY_OP_TANH: op_str = "tanh"; break;
- case GGML_UNARY_OP_RELU: op_str = "relu"; break;
- case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break;
- case GGML_UNARY_OP_GELU: op_str = "gelu"; break;
- case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break;
- case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break;
- case GGML_UNARY_OP_SILU: op_str = "silu"; break;
- case GGML_UNARY_OP_ELU: op_str = "elu"; break;
- case GGML_UNARY_OP_NEG: op_str = "neg"; break;
- case GGML_UNARY_OP_ABS: op_str = "abs"; break;
- case GGML_UNARY_OP_SGN: op_str = "sgn"; break;
- case GGML_UNARY_OP_STEP: op_str = "step"; break;
- case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
- case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
- case GGML_UNARY_OP_EXP: op_str = "exp"; break;
- case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
- case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break;
+ case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break;
+ case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break;
+ case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break;
+ case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break;
+ case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break;
+ case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break;
+ case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break;
+ case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break;
+ case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break;
+ case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break;
+ case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break;
+ case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break;
+ case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break;
+ case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
+ case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break;
+ case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
+ case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break;
default: GGML_ABORT("fatal error");
} break;
default: GGML_ABORT("fatal error");
};
- const char * suffix = "";
- if (n % 4 == 0) {
- suffix = "_4";
- }
+ const char * t0_str = ggml_type_name(op->src[0]->type); // source type name, part of the kernel name
+ const char * t_str = ggml_type_name(op->type); // destination type name, part of the kernel name
- snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
- snprintf(name, 256, "%s", base);
+ const bool is_c4 = op->src[0]->ne[0] % 4 == 0; // first source dimension divisible by 4 -> use the "_4" kernel variant
+ const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768; // contiguous source with fewer than 32768 elements
+
+ snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : ""); // base: kernel entry point, independent of the op
+ snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt); // name: lookup key, one pipeline per (op, cnt) combination
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+ ggml_metal_cv_t cv = ggml_metal_cv_init(); // constant values used to specialize the kernel, freed below
+
+ ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0); // function constant FC_UNARY + 0: op id
+ ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1); // function constant FC_UNARY + 1: contiguous flag
+
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+ ggml_metal_cv_free(cv);
}
+ res.c4 = is_c4; // report the selected variant flags on the returned pipeline (set on both the lookup and the compile path)
+ res.cnt = is_cnt;
+
return res;
}
}
switch (op->op) {
+ case GGML_OP_SCALE:
+ case GGML_OP_FILL:
+ case GGML_OP_CLAMP:
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ case GGML_OP_SIN:
+ case GGML_OP_COS:
+ case GGML_OP_LOG:
+ return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_EXPM1:
- return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+ return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
default:
return false;
}
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
case GGML_OP_REPEAT:
- case GGML_OP_SCALE:
- case GGML_OP_FILL:
case GGML_OP_CONV_TRANSPOSE_1D:
return true;
case GGML_OP_CONV_TRANSPOSE_2D:
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32;
- case GGML_OP_CLAMP:
- return op->src[0]->type == GGML_TYPE_F32;
- case GGML_OP_SQR:
- case GGML_OP_SQRT:
- case GGML_OP_SIN:
- case GGML_OP_COS:
- case GGML_OP_LOG:
- return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_TRI:
#define FC_SSM_CONV 900
#define FC_SOLVE_TRI 1000
#define FC_COUNT_EQUAL 1100
-#define FC_BIN 1200
+#define FC_UNARY 1200
+#define FC_BIN 1300
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
+#define OP_UNARY_NUM_SCALE 10 // 10..18: ids for ops selected directly by op->op (GGML_OP_*)
+#define OP_UNARY_NUM_FILL 11
+#define OP_UNARY_NUM_CLAMP 12
+#define OP_UNARY_NUM_SQR 13
+#define OP_UNARY_NUM_SQRT 14
+#define OP_UNARY_NUM_SIN 15
+#define OP_UNARY_NUM_COS 16
+#define OP_UNARY_NUM_LOG 17
+#define OP_UNARY_NUM_LEAKY_RELU 18
+
+#define OP_UNARY_NUM_TANH 100 // 100..116: ids for GGML_OP_UNARY sub-ops (GGML_UNARY_OP_*)
+#define OP_UNARY_NUM_RELU 101
+#define OP_UNARY_NUM_SIGMOID 102
+#define OP_UNARY_NUM_GELU 103
+#define OP_UNARY_NUM_GELU_ERF 104
+#define OP_UNARY_NUM_GELU_QUICK 105
+#define OP_UNARY_NUM_SILU 106
+#define OP_UNARY_NUM_ELU 107
+#define OP_UNARY_NUM_NEG 108
+#define OP_UNARY_NUM_ABS 109
+#define OP_UNARY_NUM_SGN 110
+#define OP_UNARY_NUM_STEP 111
+#define OP_UNARY_NUM_HARDSWISH 112
+#define OP_UNARY_NUM_HARDSIGMOID 113
+#define OP_UNARY_NUM_EXP 114
+#define OP_UNARY_NUM_SOFTPLUS 115
+#define OP_UNARY_NUM_EXPM1 116
+
+
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
int32_t dim;
} ggml_metal_kargs_concat;
+typedef struct { // arguments passed to the unified unary kernel (kernel_unary_impl)
+ int32_t ne00; // ne00..ne03: source element counts; the host stores ne00/4 here for the "_4" variant
+ int32_t ne01; // the kernel divides the x threadgroup index by ne01 to recover the row index
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00; // nb00..nb03: source strides; the kernel adds i01*nb01 + i02*nb02 + i03*nb03 to the src0 pointer
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne0; // ne0..ne3: destination element counts; ne0 bounds the in-row index (ne0/4 for the "_4" variant)
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ uint64_t nb0; // nb0..nb3: destination strides; the kernel adds i01*nb1 + i02*nb2 + i03*nb3 to the dst pointer
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+ float slope; // read by OP_UNARY_NUM_LEAKY_RELU
+ float scale; // read by OP_UNARY_NUM_SCALE
+ float bias; // read by OP_UNARY_NUM_SCALE
+ float val; // read by OP_UNARY_NUM_FILL
+ float min; // read by OP_UNARY_NUM_CLAMP
+ float max; // read by OP_UNARY_NUM_CLAMP
+} ggml_metal_kargs_unary;
+
typedef struct {
int32_t ne00;
int32_t ne01;
uint64_t nb3;
} ggml_metal_kargs_repeat;
-typedef struct {
- float scale;
- float bias;
-} ggml_metal_kargs_scale;
-
-typedef struct {
- float val;
-} ggml_metal_kargs_fill;
-
-typedef struct {
- float min;
- float max;
-} ggml_metal_kargs_clamp;
-
typedef struct {
int64_t nk0;
int64_t ne00;
int max_period;
} ggml_metal_kargs_timestep_embedding;
-typedef struct {
- float slope;
-} ggml_metal_kargs_leaky_relu;
-
typedef struct {
int32_t ne00;
int32_t ne01;
n_fuse = ggml_metal_op_acc(ctx, idx);
} break;
case GGML_OP_SCALE:
- {
- n_fuse = ggml_metal_op_scale(ctx, idx);
- } break;
case GGML_OP_FILL:
- {
- n_fuse = ggml_metal_op_fill(ctx, idx);
- } break;
case GGML_OP_CLAMP:
- {
- n_fuse = ggml_metal_op_clamp(ctx, idx);
- } break;
+ case GGML_OP_LEAKY_RELU:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
{
n_fuse = ggml_metal_op_top_k(ctx, idx);
} break;
- case GGML_OP_LEAKY_RELU:
- {
- n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
- } break;
case GGML_OP_TRI:
{
n_fuse = ggml_metal_op_tri(ctx, idx);
return 1;
}
-int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
+int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
- float scale;
- float bias;
- memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
- memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
- ggml_metal_kargs_scale args = {
- /*.scale =*/ scale,
- /*.bias =*/ bias,
- };
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
- int64_t n = ggml_nelements(op);
+ ggml_metal_kargs_unary args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.slope =*/ 0.0,
+ /*.scale =*/ 0.0,
+ /*.bias =*/ 0.0,
+ /*.val =*/ 0.0,
+ /*.min =*/ 0.0,
+ /*.max =*/ 0.0,
+ };
- if (n % 4 == 0) {
- n /= 4;
+ if (op->op == GGML_OP_LEAKY_RELU) {
+ args.slope = ggml_get_op_params_f32(op, 0);
}
- auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
- ggml_metal_encoder_set_pipeline(enc, pipeline);
- ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
-
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
- return 1;
-}
-
-int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
- ggml_tensor * op = ctx->node(idx);
-
- ggml_metal_library_t lib = ctx->lib;
- ggml_metal_encoder_t enc = ctx->enc;
-
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
- GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
-
- const float val = ggml_get_op_params_f32(op, 0);
-
- ggml_metal_kargs_fill args = {
- /*.val =*/ val
- };
+ if (op->op == GGML_OP_SCALE) {
+ args.scale = ggml_get_op_params_f32(op, 0);
+ args.bias = ggml_get_op_params_f32(op, 1);
+ }
- int64_t n = ggml_nelements(op);
+ if (op->op == GGML_OP_FILL) {
+ args.val = ggml_get_op_params_f32(op, 0);
+ }
- if (n % 4 == 0) {
- n /= 4;
+ if (op->op == GGML_OP_CLAMP) {
+ args.min = ggml_get_op_params_f32(op, 0);
+ args.max = ggml_get_op_params_f32(op, 1);
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
- ggml_metal_encoder_set_pipeline(enc, pipeline);
- ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
-
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
- return 1;
-}
-
-int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
- ggml_tensor * op = ctx->node(idx);
-
- ggml_metal_library_t lib = ctx->lib;
- ggml_metal_encoder_t enc = ctx->enc;
-
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
- GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
-
- float min;
- float max;
- memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
- memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
-
- ggml_metal_kargs_clamp args = {
- /*.min =*/ min,
- /*.max =*/ max,
- };
-
- int64_t n = ggml_nelements(op);
-
- if (n % 4 == 0) {
- n /= 4;
+ if (pipeline.c4) {
+ args.ne00 = ne00/4;
+ args.ne0 = ne0/4;
}
- auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+ if (pipeline.cnt) {
+ const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
- return 1;
-}
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+ } else {
+ const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
-int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
- ggml_tensor * op = ctx->node(idx);
+ const int nth = MIN(args.ne00, nth_max);
- ggml_metal_library_t lib = ctx->lib;
- ggml_metal_encoder_t enc = ctx->enc;
+ const int nk0 = (args.ne00 + nth - 1)/nth;
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
- GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
-
- int64_t n = ggml_nelements(op);
-
- if (n % 4 == 0) {
- n /= 4;
+ ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
}
- auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
- ggml_metal_encoder_set_pipeline(enc, pipeline);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
-
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
return 1;
}
return 1;
}
-int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
- ggml_tensor * op = ctx->node(idx);
-
- ggml_metal_library_t lib = ctx->lib;
- ggml_metal_encoder_t enc = ctx->enc;
-
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
- GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
-
- float slope;
- memcpy(&slope, op->op_params, sizeof(float));
-
- ggml_metal_kargs_leaky_relu args = {
- /*.slope =*/ slope
- };
-
- auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
- int64_t n = ggml_nelements(op);
-
- if (n % 4 == 0) {
- n /= 4;
- }
-
- ggml_metal_encoder_set_pipeline(enc, pipeline);
- ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
-
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
- return 1;
-}
-
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
GGML_SORT_ORDER_DESC,
};
+constant float GELU_COEF_A = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
+
+// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
+// ref: https://www.johndcook.com/blog/python_erf/
+constant float p_erf = 0.3275911f;
+constant float a1_erf = 0.254829592f;
+constant float a2_erf = -0.284496736f;
+constant float a3_erf = 1.421413741f;
+constant float a4_erf = -1.453152027f;
+constant float a5_erf = 1.061405429f;
+
+// Polynomial approximation of erf(x) built from the p_erf / a1_erf..a5_erf coefficients.
+// The magnitude is evaluated on |x| and the sign of the input is re-applied to the result.
+template<typename T>
+T erf_approx(T x) {
+    const T sgn = sign(x);
+    const T ax  = fabs(x);
+
+    const T t = 1.0f / (1.0f + p_erf * ax);
+
+    // Horner evaluation of the degree-5 polynomial in t (same operation order as the closed form)
+    const T poly = (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t;
+
+    const T y = 1.0f - poly * exp(-ax * ax);
+
+    return sgn * y;
+}
+
+constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
+constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
+
+// Unified element-wise unary kernel.
+//
+// The operation is chosen when the pipeline is compiled, through the function constant FC_unary_op
+// (one of the OP_UNARY_NUM_* ids), so a given pipeline executes exactly one of the branches below.
+//
+// The function constant FC_unary_cnt selects the addressing mode:
+//   - true : src0/dst are treated as flat arrays and the element index is the threadgroup position (tgpig.x)
+//   - false: tgpig.x encodes (k0, i01) as k0*ne01 + i01, tgpig.y/z are i02/i03; the thread handles element
+//            i0 = k0*ntg.x + tpitg.x of row (i01, i02, i03), and rows are located with the byte strides
+//            nb01..nb03 (source) and nb1..nb3 (destination). Threads with i0 >= ne0 exit early.
+//
+// T0 is the source element type and T the destination element type.
+template <typename T0, typename T>
+kernel void kernel_unary_impl(
+        constant ggml_metal_kargs_unary & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+#define FC_OP  FC_unary_op
+#define FC_CNT FC_unary_cnt
+
+    device const T0 * src0_ptr;
+    device       T  * dst_ptr;
+
+    int i0;
+
+    if (FC_CNT) {
+        // flat addressing: one element per threadgroup position
+        i0 = tgpig.x;
+
+        src0_ptr = (device const T0 *) (src0);
+        dst_ptr  = (device       T  *) (dst);
+    } else {
+        // row addressing: split tgpig.x into the chunk index k0 and the row index i01
+        const int i03 = tgpig.z;
+        const int i02 = tgpig.y;
+        const int k0  = tgpig.x/args.ne01;
+        const int i01 = tgpig.x - k0*args.ne01;
+
+        i0 = k0*ntg.x + tpitg.x;
+
+        src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+        dst_ptr  = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1 );
+    }
+
+    {
+        //threadgroup_barrier(mem_flags::mem_none);
+
+        if (!FC_CNT) {
+            // the last chunk of a row may be only partially filled
+            if (i0 >= args.ne0) {
+                return;
+            }
+        }
+
+        device const T0 & x = src0_ptr[i0];
+
+        if (FC_OP == OP_UNARY_NUM_SCALE) {
+            dst_ptr[i0] = args.scale * x + args.bias;
+        }
+
+        if (FC_OP == OP_UNARY_NUM_FILL) {
+            dst_ptr[i0] = args.val;
+        }
+
+        if (FC_OP == OP_UNARY_NUM_CLAMP) {
+            dst_ptr[i0] = clamp(x, args.min, args.max);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SQR) {
+            dst_ptr[i0] = x * x;
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SQRT) {
+            dst_ptr[i0] = sqrt(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SIN) {
+            dst_ptr[i0] = sin(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_COS) {
+            dst_ptr[i0] = cos(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_LOG) {
+            dst_ptr[i0] = log(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
+            // per-component select instead of a mask-multiply blend: multiplying the unused
+            // branch by 0 turns an infinite input into NaN (0*inf)
+            dst_ptr[i0] = select(x * args.slope, x, x > 0.0f);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_TANH) {
+            dst_ptr[i0] = precise::tanh(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_RELU) {
+            dst_ptr[i0] = fmax(0.0f, x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SIGMOID) {
+            dst_ptr[i0] = 1.0f / (1.0f + exp(-x));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_GELU) {
+            dst_ptr[i0] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
+            dst_ptr[i0] = 0.5f*x*(1.0f + erf_approx(SQRT_2_INV*x));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
+            dst_ptr[i0] = x * (1.0f/(1.0f + exp(GELU_QUICK_COEF*x)));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SILU) {
+            dst_ptr[i0] = x / (1.0f + exp(-x));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_ELU) {
+            // per-component select instead of a mask-multiply blend: exp(x) overflows to +inf for
+            // large positive x and 0*inf would make the result NaN instead of x
+            dst_ptr[i0] = select(exp(x) - 1.0f, x, x > 0.0f);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_NEG) {
+            dst_ptr[i0] = -x;
+        }
+
+        if (FC_OP == OP_UNARY_NUM_ABS) {
+            dst_ptr[i0] = fabs(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SGN) {
+            dst_ptr[i0] = T(x > 0.0f) - T(x < 0.0f);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_STEP) {
+            dst_ptr[i0] = T(x > 0.0f);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
+            dst_ptr[i0] = x * fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
+            dst_ptr[i0] = fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_EXP) {
+            dst_ptr[i0] = exp(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
+            // for x > 20 the result is x itself
+            dst_ptr[i0] = select(log(1.0f + exp(x)), x, x > 20.0f);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_EXPM1) {
+            // TODO: precise implementation
+            dst_ptr[i0] = exp(x) - 1.0f;
+        }
+    }
+
+#undef FC_OP
+#undef FC_CNT
+}
+
+typedef decltype(kernel_unary_impl<float, float>) kernel_unary_t;
+
+template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float>;
+template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4>;
+
+
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
-kernel void kernel_scale_f32(
- constant ggml_metal_kargs_scale & args,
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * args.scale + args.bias;
-}
-
-kernel void kernel_scale_f32_4(
- constant ggml_metal_kargs_scale & args,
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * args.scale + args.bias;
-}
-
-kernel void kernel_fill_f32(
- constant ggml_metal_kargs_fill & args,
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = args.val;
-}
-
-kernel void kernel_fill_f32_4(
- constant ggml_metal_kargs_fill & args,
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = args.val;
-}
-
-kernel void kernel_clamp_f32(
- constant ggml_metal_kargs_clamp & args,
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = clamp(src0[tpig], args.min, args.max);
-}
-
-kernel void kernel_clamp_f32_4(
- constant ggml_metal_kargs_clamp & args,
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = clamp(src0[tpig], args.min, args.max);
-}
-
-kernel void kernel_relu_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = max(0.0f, src0[tpig]);
-}
-
-kernel void kernel_relu_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = max(0.0f, src0[tpig]);
-}
-
-kernel void kernel_sigmoid_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
-}
-
-kernel void kernel_sigmoid_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
-}
-
-kernel void kernel_tanh_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = precise::tanh(src0[tpig]);
-}
-
-kernel void kernel_tanh_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = precise::tanh(src0[tpig]);
-}
-
-constant float GELU_COEF_A = 0.044715f;
-constant float GELU_QUICK_COEF = -1.702f;
-constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
-
-kernel void kernel_gelu_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float & x = src0[tpig];
-
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-kernel void kernel_gelu_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
-
- // BEWARE !!!
- // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
- // This was observed with Falcon 7B and 40B models
- //
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-kernel void kernel_gelu_quick_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float & x = src0[tpig];
-
- dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
-
-kernel void kernel_gelu_quick_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
-
- dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
-
-// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
-// ref: https://www.johndcook.com/blog/python_erf/
-constant float p_erf = 0.3275911f;
-constant float a1_erf = 0.254829592f;
-constant float a2_erf = -0.284496736f;
-constant float a3_erf = 1.421413741f;
-constant float a4_erf = -1.453152027f;
-constant float a5_erf = 1.061405429f;
-
-template<typename T>
-T erf_approx(T x) {
- T sign_x = sign(x);
- x = fabs(x);
- T t = 1.0f / (1.0f + p_erf * x);
- T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
- return sign_x * y;
-}
-
-kernel void kernel_gelu_erf_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float & x = src0[tpig];
-
- dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
-}
-
-kernel void kernel_gelu_erf_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
-
- dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
-}
-
-kernel void kernel_silu_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float & x = src0[tpig];
- dst[tpig] = x / (1.0f + exp(-x));
-}
-
-kernel void kernel_silu_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
- dst[tpig] = x / (1.0f + exp(-x));
-}
-
-kernel void kernel_elu_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- const float x = src0[tpig];
- dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
-}
-
-kernel void kernel_elu_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- const float4 x = src0[tpig];
- dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
- dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
- dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
- dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
-}
-
-kernel void kernel_sqr_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * src0[tpig];
-}
-
-kernel void kernel_sqr_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * src0[tpig];
-}
-
-kernel void kernel_sqrt_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = sqrt(src0[tpig]);
-}
-
-kernel void kernel_sqrt_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = sqrt(src0[tpig]);
-}
-
-kernel void kernel_sin_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = sin(src0[tpig]);
-}
-
-kernel void kernel_sin_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = sin(src0[tpig]);
-}
-
-kernel void kernel_cos_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = cos(src0[tpig]);
-}
-
-kernel void kernel_cos_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = cos(src0[tpig]);
-}
-
-kernel void kernel_log_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = log(src0[tpig]);
-}
-
-kernel void kernel_log_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = log(src0[tpig]);
-}
-
-kernel void kernel_neg_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = -src0[tpig];
-}
-
-kernel void kernel_neg_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = -src0[tpig];
-}
-
-kernel void kernel_abs_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = fabs(src0[tpig]);
-}
-
-kernel void kernel_abs_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = fabs(src0[tpig]);
-}
-
-kernel void kernel_sgn_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = sign(src0[tpig]);
-}
-
-kernel void kernel_sgn_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = sign(src0[tpig]);
-}
-
-kernel void kernel_step_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = step(0.0f, src0[tpig]);
-}
-
-kernel void kernel_step_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = step(0.0f, src0[tpig]);
-}
-
-kernel void kernel_hardswish_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- const float x = src0[tpig];
- dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_hardswish_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- const float4 x = src0[tpig];
- dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_hardsigmoid_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- const float x = src0[tpig];
- dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_hardsigmoid_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- const float4 x = src0[tpig];
- dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_exp_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = exp(src0[tpig]);
-}
-
-kernel void kernel_exp_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = exp(src0[tpig]);
-}
-
-kernel void kernel_softplus_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float & x = src0[tpig];
- dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
-}
-
-kernel void kernel_softplus_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
- dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
-}
-
-kernel void kernel_expm1_f32(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = exp(src0[tpig]) - 1.0f;
-}
-
-kernel void kernel_expm1_f32_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = exp(src0[tpig]) - 1.0f;
-}
-
kernel void kernel_reglu_f32(
constant ggml_metal_kargs_glu & args,
device const char * src0,
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
-kernel void kernel_leaky_relu_f32(
- constant ggml_metal_kargs_leaky_relu & args,
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- const float x = src0[tpig];
- dst[tpig] = x > 0.0f ? x : x * args.slope;
-}
-
-kernel void kernel_leaky_relu_f32_4(
- constant ggml_metal_kargs_leaky_relu & args,
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- const float4 x = src0[tpig];
- dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
-}
-
constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
template<typename T>
kernel void kernel_memset(
- constant ggml_metal_kargs_fill & args,
+ constant ggml_metal_kargs_memset & args,
device T * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
test_cases.emplace_back(new test_round (type));
test_cases.emplace_back(new test_trunc (type));
test_cases.emplace_back(new test_sqr (type, {7, 1, 5, 3}));
+ test_cases.emplace_back(new test_sqr (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_sqrt (type, {7, 1, 5, 3}));
+ test_cases.emplace_back(new test_sqrt (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_log (type, {7, 1, 5, 3}));
+ test_cases.emplace_back(new test_log (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_sin (type, {7, 1, 5, 3}));
+ test_cases.emplace_back(new test_sin (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_cos (type, {7, 1, 5, 3}));
+ test_cases.emplace_back(new test_cos (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3}));
+ test_cases.emplace_back(new test_clamp (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3}));
+ test_cases.emplace_back(new test_leaky_relu(type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3}));
- test_cases.emplace_back(new test_floor (type, { 1024, 1024, 1, 1 }));
+ test_cases.emplace_back(new test_floor (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3}));
- test_cases.emplace_back(new test_ceil (type, { 1024, 1024, 1, 1 }));
+ test_cases.emplace_back(new test_ceil (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_round (type, {7, 1, 5, 3}));
- test_cases.emplace_back(new test_round (type, { 1024, 1024, 1, 1 }));
+ test_cases.emplace_back(new test_round (type, {1024, 1024, 1, 1}));
test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3}));
- test_cases.emplace_back(new test_trunc (type, { 1024, 1024, 1, 1 }));
+ test_cases.emplace_back(new test_trunc (type, {1024, 1024, 1, 1}));
}
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));