@end
enum ggml_metal_kernel_type {
- GGML_METAL_KERNEL_TYPE_ADD,
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
- GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
- GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
- GGML_METAL_KERNEL_TYPE_SUB,
- GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
- GGML_METAL_KERNEL_TYPE_MUL,
- GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
- GGML_METAL_KERNEL_TYPE_DIV,
- GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
GGML_METAL_KERNEL_TYPE_ADD_ID,
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
- GGML_METAL_KERNEL_TYPE_RMS_NORM,
- GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
- GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
GGML_METAL_KERNEL_TYPE_L2_NORM,
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
GGML_METAL_KERNEL_TYPE_NORM,
// simd_sum and simd_max requires MTLGPUFamilyApple7
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
GGML_UNUSED(op);
}
+static id<MTLComputePipelineState> ggml_metal_get_pipeline_bin(
+ ggml_backend_t backend, enum ggml_op op,
+ int32_t n_fuse,
+ bool row) {
+ struct ggml_backend_metal_context * ctx = backend->context;
+
+ char base[256];
+ char name[256];
+
+ @autoreleasepool {
+ const char * op_str = "undefined";
+ switch (op) {
+ case GGML_OP_ADD: op_str = "add"; break;
+ case GGML_OP_SUB: op_str = "sub"; break;
+ case GGML_OP_MUL: op_str = "mul"; break;
+ case GGML_OP_DIV: op_str = "div"; break;
+ default: GGML_ABORT("fatal error");
+ };
+
+ if (row) {
+ snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
+ } else {
+ snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
+ }
+
+ snprintf(name, 256, "%s", base);
+
+ id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
+ if (res) {
+ // kernel found
+ return res;
+ }
+
+ return ggml_metal_compile_kernel(backend, base, name, nil);
+ }
+}
+
+static id<MTLComputePipelineState> ggml_metal_get_pipeline_rms_norm(
+ ggml_backend_t backend, struct ggml_tensor * op,
+ int32_t n_fuse) {
+ struct ggml_backend_metal_context * ctx = backend->context;
+
+ char base[256];
+ char name[256];
+
+ @autoreleasepool {
+ switch (n_fuse) {
+ case 1: snprintf(base, 256, "kernel_rms_norm"); break;
+ case 2: snprintf(base, 256, "kernel_rms_norm_mul"); break;
+ case 3: snprintf(base, 256, "kernel_rms_norm_mul_add"); break;
+ default: GGML_ABORT("fatal error");
+ }
+
+ snprintf(name, 256, "%s", base);
+
+ id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
+ if (res) {
+ // kernel found
+ return res;
+ }
+
+ return ggml_metal_compile_kernel(backend, base, name, nil);
+ }
+
+ GGML_UNUSED(op);
+}
+
static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
GGML_LOG_INFO("%s: deallocating\n", __func__);
bool bcast_row = false;
- id<MTLComputePipelineState> pipeline = nil;
-
ggml_metal_kargs_bin args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
}
}
+ id<MTLComputePipelineState> pipeline = nil;
+
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(src0));
// src1 is a row
GGML_ASSERT(ne11 == 1);
- switch (dst->op) {
- case GGML_OP_ADD:
- {
- switch (n_fuse) {
- case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break;
- case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
- case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
- case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
- case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
- case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
- case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
- case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
- default: GGML_ABORT("fatal error");
- }
- } break;
- case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
- default: GGML_ABORT("fatal error");
- }
+ pipeline = ggml_metal_get_pipeline_bin(backend, dst->op, n_fuse, true);
bcast_row = true;
} else {
- switch (dst->op) {
- case GGML_OP_ADD:
- {
- switch (n_fuse) {
- case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break;
- case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
- case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
- case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
- case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
- case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
- case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
- case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
- default: GGML_ABORT("fatal error");
- }
- } break;
- case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
- default: GGML_ABORT("fatal error");
- }
+ pipeline = ggml_metal_get_pipeline_bin(backend, dst->op, n_fuse, false);
}
if (n_fuse > 1) {
ggml_metal_encode_concurrency_reset(ctx_enc);
}
- const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
-
ggml_metal_kargs_bin args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.o1 =*/ { offs_src1},
};
+ //const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
+ const id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_bin(backend, GGML_OP_ADD, 1, false);
+
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
}
}
- id<MTLComputePipelineState> pipeline;
-
- switch (n_fuse) {
- case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break;
- case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break;
- case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
- default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
- }
+ const id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_rms_norm(backend, node, n_fuse);
int nth = 32; // SIMD width
typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
-template [[host_name("kernel_add")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
+template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
-kernel void kernel_sub(
+kernel void kernel_sub_fuse_1(
constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
}
}
-kernel void kernel_mul(
+kernel void kernel_mul_fuse_1(
constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
}
}
-kernel void kernel_div(
+kernel void kernel_div_fuse_1(
constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
device const char * src1,
device char * dst,
uint tpig[[thread_position_in_grid]]) {
-
const uint nb = args.ne00/4;
const uint i = tpig % nb;
device const float4 * src0_row = (device const float4 *) (src0);
device float4 * dst_row = (device float4 *) (dst);
- device const float4 * src1_row[F];
- for (short j = 0; j < F; ++j) {
- src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
- }
-
float4 res = src0_row[tpig];
#pragma unroll(F)
for (short j = 0; j < F; ++j) {
- res += src1_row[j][i];
+ res += ((device const float4 *) (src1 + args.o1[j]))[i];
}
dst_row[tpig] = res;
typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
-template [[host_name("kernel_add_row_c4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
+template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
-template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
+template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
template <short F>
kernel void kernel_mul_row_c4_fuse_impl(
typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
-template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
+template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
template <short F>
kernel void kernel_div_row_c4_fuse_impl(
typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
-template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
+template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
kernel void kernel_scale(
device const float * src0,