GGML_UNUSED(op);
}
-ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
- ggml_metal_library_t lib,
- ggml_op op,
- int32_t n_fuse,
- bool row) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
char base[256];
char name[256];
- const char * op_str = "undefined";
- switch (op) {
- case GGML_OP_ADD: op_str = "add"; break;
- case GGML_OP_SUB: op_str = "sub"; break;
- case GGML_OP_MUL: op_str = "mul"; break;
- case GGML_OP_DIV: op_str = "div"; break;
+ int op_num = -1;
+
+ switch (op->op) {
+ case GGML_OP_ADD: op_num = 0; break;
+ case GGML_OP_SUB: op_num = 1; break;
+ case GGML_OP_MUL: op_num = 2; break;
+ case GGML_OP_DIV: op_num = 3; break;
default: GGML_ABORT("fatal error");
};
- if (row) {
- snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
- } else {
- snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
+ const char * t0_str = ggml_type_name(op->src[0]->type);
+ const char * t1_str = ggml_type_name(op->src[1]->type);
+ const char * t_str = ggml_type_name(op->type);
+
+ const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
+
+ const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
+
+ snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
+ snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);
+
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+ if (!res.pipeline) {
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+ ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
+ ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
+ ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2);
+
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+ ggml_metal_cv_free(cv);
}
- snprintf(name, 256, "%s", base);
+ res.c4 = is_c4;
+ res.cnt = is_rb;
+
+ return res;
+}
+
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) {
+ char base[256];
+ char name[256];
+
+ int op_num = -1;
+
+ switch (op) {
+ case GGML_OP_ADD: op_num = 0; break;
+ case GGML_OP_SUB: op_num = 1; break;
+ case GGML_OP_MUL: op_num = 2; break;
+ case GGML_OP_DIV: op_num = 3; break;
+ default: GGML_ABORT("fatal error");
+ };
+
+ snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32");
+ snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+ ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
+ ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1);
+ ggml_metal_cv_set_bool (cv, false, FC_BIN + 2);
+
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+ ggml_metal_cv_free(cv);
}
return res;
GGML_SORT_ORDER_DESC,
};
-// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
-// pros: works for non-contiguous tensors, supports broadcast across all dims
-// cons: not very efficient
-template <int F>
-kernel void kernel_add_fuse_impl(
+// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
+constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
+constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
+constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
+
+template <typename T0, typename T1, typename T>
+kernel void kernel_bin_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
- const int i03 = tgpig.z;
- const int i02 = tgpig.y;
- const int i01 = tgpig.x;
+#define FC_OP FC_bin_op
+#define FC_F FC_bin_f
+#define FC_RB FC_bin_rb
- const int i13 = i03%args.ne13;
- const int i12 = i02%args.ne12;
- const int i11 = i01%args.ne11;
+ if (FC_RB) {
+ // row broadcast
+ const uint i0 = tgpig.x;
+ const uint i1 = i0%args.ne10;
- device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
- device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
+ device const T0 * src0_row = (device const T0 *) (src0);
+ device T * dst_row = (device T *) (dst);
- device const float * src1_ptr[F];
- for (short j = 0; j < F; ++j) {
- src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
- }
+ if (FC_F == 1) {
+ device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- const int i10 = i0%args.ne10;
+ if (FC_OP == 0) {
+ dst_row[i0] = src0_row[i0] + src1_row[i1];
+ }
+
+ if (FC_OP == 1) {
+ dst_row[i0] = src0_row[i0] - src1_row[i1];
+ }
+
+ if (FC_OP == 2) {
+ dst_row[i0] = src0_row[i0] * src1_row[i1];
+ }
+
+ if (FC_OP == 3) {
+ dst_row[i0] = src0_row[i0] / src1_row[i1];
+ }
+ } else {
+ T0 res = src0_row[i0];
+
+ if (FC_OP == 0) {
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+ res += ((device const T1 *) (src1 + args.o1[j]))[i1];
+ }
+ }
+
+ if (FC_OP == 1) {
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+ res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
+ }
+ }
- float res = src0_ptr[i0];
+ if (FC_OP == 2) {
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+ res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
+ }
+ }
+
+ if (FC_OP == 3) {
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+ res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
+ }
+ }
-#pragma unroll
- for (short j = 0; j < F; ++j) {
- res += src1_ptr[j][i10];
+ dst_row[i0] = res;
}
+ } else {
+ const int i03 = tgpig.z;
+ const int i02 = tgpig.y;
+ const int i01 = tgpig.x;
- dst_ptr[i0] = res;
- }
-}
+ if (i01 >= args.ne01) {
+ return;
+ }
-typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
+ const int i13 = i03%args.ne13;
+ const int i12 = i02%args.ne12;
+ const int i11 = i01%args.ne11;
-template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
-template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
-template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
-template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
-template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
-template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
-template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
-template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
+ device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
+ device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
-kernel void kernel_sub_fuse_1(
- constant ggml_metal_kargs_bin & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- const int i03 = tgpig.z;
- const int i02 = tgpig.y;
- const int i01 = tgpig.x;
+ if (FC_F == 1) {
+ device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
- const int i13 = i03%args.ne13;
- const int i12 = i02%args.ne12;
- const int i11 = i01%args.ne11;
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
- device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
+ if (FC_OP == 0) {
+ dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
+ }
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- const int i10 = i0%args.ne10;
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
- }
-}
+ if (FC_OP == 1) {
+ dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
+ }
-kernel void kernel_mul_fuse_1(
- constant ggml_metal_kargs_bin & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- const int i03 = tgpig.z;
- const int i02 = tgpig.y;
- const int i01 = tgpig.x;
+ if (FC_OP == 2) {
+ dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
+ }
- const int i13 = i03%args.ne13;
- const int i12 = i02%args.ne12;
- const int i11 = i01%args.ne11;
+ if (FC_OP == 3) {
+ dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
+ }
+ }
+ } else {
+ device const T1 * src1_ptr[8];
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+ src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
+ }
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
- device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
- if (args.ne10 == 1) {
- const float x = *((device float *)(src1_ptr));
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
- }
- } else {
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- const int i10 = i0%args.ne10;
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
- }
- }
-}
+ T res = src0_ptr[i0];
-kernel void kernel_div_fuse_1(
- constant ggml_metal_kargs_bin & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- const int i03 = tgpig.z;
- const int i02 = tgpig.y;
- const int i01 = tgpig.x;
+ if (FC_OP == 0) {
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+ res += src1_ptr[j][i10];
+ }
+ }
+
+ if (FC_OP == 1) {
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+ res -= src1_ptr[j][i10];
+ }
+ }
- const int i13 = i03%args.ne13;
- const int i12 = i02%args.ne12;
- const int i11 = i01%args.ne11;
+ if (FC_OP == 2) {
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+ res *= src1_ptr[j][i10];
+ }
+ }
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
- device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
+ if (FC_OP == 3) {
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+ res /= src1_ptr[j][i10];
+ }
+ }
- if (args.ne10 == 1) {
- const float x = 1.0f / *((device float *)(src1_ptr));
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
- }
- } else {
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- const int i10 = i0%args.ne10;
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+ dst_ptr[i0] = res;
+ }
}
}
+
+#undef FC_OP
+#undef FC_F
+#undef FC_RB
}
+typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
+
+template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
+template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
+
kernel void kernel_add_id(
constant ggml_metal_kargs_add_id & args,
device const char * src0,
const size_t nb1 = args.ne0 * sizeof(float);
const size_t nb2 = args.ne1 * nb1;
- device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
+ device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
-// assumption: src1 is a row
-// broadcast src1 into src0
-template <short F>
-kernel void kernel_add_row_c4_fuse_impl(
- constant ggml_metal_kargs_bin & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint tpig[[thread_position_in_grid]]) {
- const uint nb = args.ne00/4;
- const uint i = tpig % nb;
-
- device const float4 * src0_row = (device const float4 *) (src0);
- device float4 * dst_row = (device float4 *) (dst);
-
- float4 res = src0_row[tpig];
-
-#pragma unroll(F)
- for (short j = 0; j < F; ++j) {
- res += ((device const float4 *) (src1 + args.o1[j]))[i];
- }
-
- dst_row[tpig] = res;
-}
-
-typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
-
-template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
-template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
-template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
-template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
-template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
-template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
-template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
-template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
-
-template <short F>
-kernel void kernel_sub_row_c4_fuse_impl(
- constant ggml_metal_kargs_bin & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint tpig[[thread_position_in_grid]]) {
-
- const uint nb = args.ne00/4;
- const uint i = tpig % nb;
-
- device const float4 * src0_row = (device const float4 *) (src0);
- device float4 * dst_row = (device float4 *) (dst);
-
- device const float4 * src1_row[F];
- for (short j = 0; j < F; ++j) {
- src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
- }
-
- float4 res = src0_row[tpig];
-
-#pragma unroll(F)
- for (short j = 0; j < F; ++j) {
- res -= src1_row[j][i];
- }
-
- dst_row[tpig] = res;
-}
-
-typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
-
-template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
-
-template <short F>
-kernel void kernel_mul_row_c4_fuse_impl(
- constant ggml_metal_kargs_bin & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint tpig[[thread_position_in_grid]]) {
-
- const uint nb = args.ne00/4;
- const uint i = tpig % nb;
-
- device const float4 * src0_row = (device const float4 *) (src0);
- device float4 * dst_row = (device float4 *) (dst);
-
- device const float4 * src1_row[F];
- for (short j = 0; j < F; ++j) {
- src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
- }
-
- float4 res = src0_row[tpig];
-
-#pragma unroll(F)
- for (short j = 0; j < F; ++j) {
- res *= src1_row[j][i];
- }
-
- dst_row[tpig] = res;
-}
-
-typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
-
-template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
-
-template <short F>
-kernel void kernel_div_row_c4_fuse_impl(
- constant ggml_metal_kargs_bin & args,
- device const char * src0,
- device const char * src1,
- device char * dst,
- uint tpig[[thread_position_in_grid]]) {
-
- const uint nb = args.ne00/4;
- const uint i = tpig % nb;
-
- device const float4 * src0_row = (device const float4 *) (src0);
- device float4 * dst_row = (device float4 *) (dst);
-
- device const float4 * src1_row[F];
- for (short j = 0; j < F; ++j) {
- src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
- }
-
- float4 res = src0_row[tpig];
-
-#pragma unroll(F)
- for (short j = 0; j < F; ++j) {
- res /= src1_row[j][i];
- }
-
- dst_row[tpig] = res;
-}
-
-typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
-
-template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
-
kernel void kernel_scale_f32(
constant ggml_metal_kargs_scale & args,
device const float * src0,