From: Georgi Gerganov Date: Sat, 13 Sep 2025 13:24:22 +0000 (+0300) Subject: metal : refactor kernel loading (llama/15964) X-Git-Tag: v0.9.1~34 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=8e8d9c3256953f319a15ff36272cdcdc277f05dc;p=pkg%2Fggml%2Fsources%2Fggml metal : refactor kernel loading (llama/15964) * metal : refactor bin kernels loading ggml-ci * metal : refactor rms kernel loading ggml-ci * ci : try to add memory leaks check ggml-ci * ci : try to enable memory leak detection for Mac * cont : seems to be working --- diff --git a/src/ggml-metal/ggml-metal.m b/src/ggml-metal/ggml-metal.m index b8e06aa6..82d8077a 100644 --- a/src/ggml-metal/ggml-metal.m +++ b/src/ggml-metal/ggml-metal.m @@ -232,28 +232,6 @@ struct ggml_metal_kernel { @end enum ggml_metal_kernel_type { - GGML_METAL_KERNEL_TYPE_ADD, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, - GGML_METAL_KERNEL_TYPE_SUB, - GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, - GGML_METAL_KERNEL_TYPE_MUL, - GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, - GGML_METAL_KERNEL_TYPE_DIV, - GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, GGML_METAL_KERNEL_TYPE_ADD_ID, GGML_METAL_KERNEL_TYPE_REPEAT_F32, GGML_METAL_KERNEL_TYPE_REPEAT_F16, @@ -319,9 +297,6 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, - GGML_METAL_KERNEL_TYPE_RMS_NORM, - GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, - GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, GGML_METAL_KERNEL_TYPE_L2_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM, GGML_METAL_KERNEL_TYPE_NORM, @@ -1177,28 +1152,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de // simd_sum and simd_max requires MTLGPUFamilyApple7 - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); @@ -1264,9 +1217,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); @@ -1722,6 +1672,73 @@ static id ggml_metal_get_pipeline_flash_attn_ext_vec_re GGML_UNUSED(op); } +static id ggml_metal_get_pipeline_bin( + ggml_backend_t backend, enum ggml_op op, + int32_t n_fuse, + bool row) { + struct ggml_backend_metal_context * ctx = backend->context; + + char base[256]; + char name[256]; + + @autoreleasepool { + const char * op_str = "undefined"; + switch (op) { + case GGML_OP_ADD: op_str = "add"; break; + case GGML_OP_SUB: op_str = "sub"; break; + case GGML_OP_MUL: op_str = "mul"; break; + case GGML_OP_DIV: op_str = "div"; break; + default: GGML_ABORT("fatal error"); + }; + + if (row) { + snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse); + } else { + snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse); + } + + snprintf(name, 256, "%s", base); + + id res = ggml_metal_get_kernel(ctx, name); + if (res) { + // kernel found + return res; + } + + return ggml_metal_compile_kernel(backend, base, name, nil); + } +} + +static id ggml_metal_get_pipeline_rms_norm( + ggml_backend_t backend, struct ggml_tensor * op, + int32_t n_fuse) { + struct ggml_backend_metal_context * ctx = backend->context; + + char base[256]; + char name[256]; + + @autoreleasepool { + switch (n_fuse) { + case 1: snprintf(base, 256, "kernel_rms_norm"); break; + case 2: snprintf(base, 256, "kernel_rms_norm_mul"); break; + case 3: snprintf(base, 256, "kernel_rms_norm_mul_add"); break; + default: GGML_ABORT("fatal error"); + } + + snprintf(name, 256, "%s", base); + + id res = ggml_metal_get_kernel(ctx, name); + if (res) { + // kernel found + return res; + } + + return ggml_metal_compile_kernel(backend, base, name, nil); + } + + GGML_UNUSED(op); +} + static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { GGML_LOG_INFO("%s: deallocating\n", __func__); @@ -2359,8 +2376,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in bool bcast_row = false; - id pipeline = nil; - ggml_metal_kargs_bin args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -2441,55 +2456,19 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in } } + id pipeline = nil; + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { GGML_ASSERT(ggml_is_contiguous(src0)); // src1 is a row GGML_ASSERT(ne11 == 1); - switch (dst->op) { - case GGML_OP_ADD: - { - switch (n_fuse) { - case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break; - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break; - case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break; - case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break; - case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break; - default: GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break; - default: GGML_ABORT("fatal error"); - } + pipeline = ggml_metal_get_pipeline_bin(backend, dst->op, n_fuse, true); bcast_row = true; } else { - switch (dst->op) { - case GGML_OP_ADD: - { - switch (n_fuse) { - case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break; - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break; - case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break; - case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break; - case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break; - default: GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; - default: GGML_ABORT("fatal error"); - } + pipeline = ggml_metal_get_pipeline_bin(backend, dst->op, n_fuse, false); } if (n_fuse > 1) { @@ -2650,8 +2629,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in ggml_metal_encode_concurrency_reset(ctx_enc); } - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; - ggml_metal_kargs_bin args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -2681,6 +2658,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in /*.o1 =*/ { offs_src1}, }; + //const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; + const id pipeline = ggml_metal_get_pipeline_bin(backend, GGML_OP_ADD, 1, false); + [encoder setComputePipelineState:pipeline]; [encoder setBytes:&args length:sizeof(args) atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; @@ -4659,14 +4639,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in } } - id pipeline; - - switch (n_fuse) { - case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break; - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break; - default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse); - } + const id pipeline = ggml_metal_get_pipeline_rms_norm(backend, node, n_fuse); int nth = 32; // SIMD width diff --git a/src/ggml-metal/ggml-metal.metal b/src/ggml-metal/ggml-metal.metal index 157d0cc6..4314c9cc 100644 --- a/src/ggml-metal/ggml-metal.metal +++ b/src/ggml-metal/ggml-metal.metal @@ -928,7 +928,7 @@ kernel void kernel_add_fuse_impl( typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; -template [[host_name("kernel_add")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; +template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; @@ -937,7 +937,7 @@ template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_ template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; -kernel void kernel_sub( +kernel void kernel_sub_fuse_1( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -963,7 +963,7 @@ kernel void kernel_sub( } } -kernel void kernel_mul( +kernel void kernel_mul_fuse_1( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -996,7 +996,7 @@ kernel void kernel_mul( } } -kernel void kernel_div( +kernel void kernel_div_fuse_1( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -1096,23 +1096,17 @@ kernel void kernel_add_row_c4_fuse_impl( device const char * src1, device char * dst, uint tpig[[thread_position_in_grid]]) { - const uint nb = args.ne00/4; const uint i = tpig % nb; device const float4 * src0_row = (device const float4 *) (src0); device float4 * dst_row = (device float4 *) (dst); - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - float4 res = src0_row[tpig]; #pragma unroll(F) for (short j = 0; j < F; ++j) { - res += src1_row[j][i]; + res += ((device const float4 *) (src1 + args.o1[j]))[i]; } dst_row[tpig] = res; @@ -1120,7 +1114,7 @@ kernel void kernel_add_row_c4_fuse_impl( typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t; -template [[host_name("kernel_add_row_c4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; +template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>; template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>; template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>; @@ -1160,7 +1154,7 @@ kernel void kernel_sub_row_c4_fuse_impl( typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t; -template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; +template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; template kernel void kernel_mul_row_c4_fuse_impl( @@ -1193,7 +1187,7 @@ kernel void kernel_mul_row_c4_fuse_impl( typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t; -template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; +template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; template kernel void kernel_div_row_c4_fuse_impl( @@ -1226,7 +1220,7 @@ kernel void kernel_div_row_c4_fuse_impl( typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t; -template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; +template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; kernel void kernel_scale( device const float * src0,