#define GGML_MAX_CONCUR (2*GGML_DEFAULT_GRAPH_SIZE)
+#define GGML_METAL_MAX_KERNELS 256
+
struct ggml_metal_buffer {
const char * name;
id<MTLBuffer> metal;
};
+struct ggml_metal_kernel {
+ id<MTLFunction> function;
+ id<MTLComputePipelineState> pipeline;
+};
+
+enum ggml_metal_kernel_type {
+ GGML_METAL_KERNEL_TYPE_ADD,
+ GGML_METAL_KERNEL_TYPE_ADD_ROW,
+ GGML_METAL_KERNEL_TYPE_MUL,
+ GGML_METAL_KERNEL_TYPE_MUL_ROW,
+ GGML_METAL_KERNEL_TYPE_DIV,
+ GGML_METAL_KERNEL_TYPE_DIV_ROW,
+ GGML_METAL_KERNEL_TYPE_SCALE,
+ GGML_METAL_KERNEL_TYPE_SCALE_4,
+ GGML_METAL_KERNEL_TYPE_TANH,
+ GGML_METAL_KERNEL_TYPE_RELU,
+ GGML_METAL_KERNEL_TYPE_GELU,
+ GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+ GGML_METAL_KERNEL_TYPE_SILU,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
+ GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
+ GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
+ GGML_METAL_KERNEL_TYPE_RMS_NORM,
+ GGML_METAL_KERNEL_TYPE_GROUP_NORM,
+ GGML_METAL_KERNEL_TYPE_NORM,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_ROPE_F32,
+ GGML_METAL_KERNEL_TYPE_ROPE_F16,
+ GGML_METAL_KERNEL_TYPE_ALIBI_F32,
+ GGML_METAL_KERNEL_TYPE_IM2COL_F16,
+ GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
+ GGML_METAL_KERNEL_TYPE_PAD_F32,
+ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
+ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
+ GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
+ //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
+ //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+ GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
+ GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
+ GGML_METAL_KERNEL_TYPE_CONCAT,
+ GGML_METAL_KERNEL_TYPE_SQR,
+ GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+
+ GGML_METAL_KERNEL_TYPE_COUNT
+};
+
struct ggml_metal_context {
int n_cb;
int n_buffers;
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
+ struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS];
+
int concur_list[GGML_MAX_CONCUR];
int concur_list_len;
- // custom kernels
-#define GGML_METAL_DECL_KERNEL(name) \
- id<MTLFunction> function_##name; \
- id<MTLComputePipelineState> pipeline_##name
-
- GGML_METAL_DECL_KERNEL(add);
- GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
- GGML_METAL_DECL_KERNEL(mul);
- GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
- GGML_METAL_DECL_KERNEL(div);
- GGML_METAL_DECL_KERNEL(div_row);
- GGML_METAL_DECL_KERNEL(scale);
- GGML_METAL_DECL_KERNEL(scale_4);
- GGML_METAL_DECL_KERNEL(tanh);
- GGML_METAL_DECL_KERNEL(relu);
- GGML_METAL_DECL_KERNEL(gelu);
- GGML_METAL_DECL_KERNEL(gelu_quick);
- GGML_METAL_DECL_KERNEL(silu);
- GGML_METAL_DECL_KERNEL(soft_max);
- GGML_METAL_DECL_KERNEL(soft_max_4);
- GGML_METAL_DECL_KERNEL(diag_mask_inf);
- GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
- GGML_METAL_DECL_KERNEL(get_rows_f32);
- GGML_METAL_DECL_KERNEL(get_rows_f16);
- GGML_METAL_DECL_KERNEL(get_rows_q4_0);
- GGML_METAL_DECL_KERNEL(get_rows_q4_1);
- GGML_METAL_DECL_KERNEL(get_rows_q5_0);
- GGML_METAL_DECL_KERNEL(get_rows_q5_1);
- GGML_METAL_DECL_KERNEL(get_rows_q8_0);
- GGML_METAL_DECL_KERNEL(get_rows_q2_K);
- GGML_METAL_DECL_KERNEL(get_rows_q3_K);
- GGML_METAL_DECL_KERNEL(get_rows_q4_K);
- GGML_METAL_DECL_KERNEL(get_rows_q5_K);
- GGML_METAL_DECL_KERNEL(get_rows_q6_K);
- GGML_METAL_DECL_KERNEL(get_rows_i32);
- GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
- GGML_METAL_DECL_KERNEL(get_rows_iq2_xs);
- GGML_METAL_DECL_KERNEL(rms_norm);
- GGML_METAL_DECL_KERNEL(group_norm);
- GGML_METAL_DECL_KERNEL(norm);
- GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
- GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
- GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_iq2_xs_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
- //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
- GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
- //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
- //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
- GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xs_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_iq2_xs_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
- GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xs_f32);
- GGML_METAL_DECL_KERNEL(rope_f32);
- GGML_METAL_DECL_KERNEL(rope_f16);
- GGML_METAL_DECL_KERNEL(alibi_f32);
- GGML_METAL_DECL_KERNEL(im2col_f16);
- GGML_METAL_DECL_KERNEL(upscale_f32);
- GGML_METAL_DECL_KERNEL(pad_f32);
- GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
- GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
- GGML_METAL_DECL_KERNEL(leaky_relu_f32);
- GGML_METAL_DECL_KERNEL(cpy_f32_f16);
- GGML_METAL_DECL_KERNEL(cpy_f32_f32);
- GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
- GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
- GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
- //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
- //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
- GGML_METAL_DECL_KERNEL(cpy_f16_f16);
- GGML_METAL_DECL_KERNEL(cpy_f16_f32);
- GGML_METAL_DECL_KERNEL(concat);
- GGML_METAL_DECL_KERNEL(sqr);
- GGML_METAL_DECL_KERNEL(sum_rows);
-
-#undef GGML_METAL_DECL_KERNEL
+ bool support_simdgroup_reduction;
+ bool support_simdgroup_mm;
};
// MSL code
return NULL;
}
- MTLCompileOptions* options = nil;
+ // dictionary of preprocessor macros
+ NSMutableDictionary * prep = [NSMutableDictionary dictionary];
+
#ifdef GGML_QKK_64
- options = [MTLCompileOptions new];
- options.preprocessorMacros = @{ @"QK_K" : @(64) };
+ prep[@"QK_K"] = @(64);
#endif
- // try to disable fast-math
- // NOTE: this seems to have no effect whatsoever
- // instead, in order to disable fast-math, we have to build default.metallib from the command line
- // using xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
- // and go through the "pre-compiled library found" path above
+
+ MTLCompileOptions* options = [MTLCompileOptions new];
+ options.preprocessorMacros = prep;
+
//[options setFastMathEnabled:false];
ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+
+ [options release];
+ [prep release];
}
if (error) {
// print MTL GPU family:
GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+ const NSInteger MTLGPUFamilyMetal3 = 5001;
+
// determine max supported GPU family
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
- if ([ctx->device supportsFamily:i]) {
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
- break;
+ {
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
+ break;
+ }
+ }
+
+ for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
+ break;
+ }
}
}
+ ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+ ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
+
+ ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+
+ GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
+ GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
if (ctx->device.maxTransferRate != 0) {
{
NSError * error = nil;
+ for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
+ ctx->kernels[i].function = nil;
+ ctx->kernels[i].pipeline = nil;
+ }
+
/*
- GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
- (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
- (int) ctx->pipeline_##name.threadExecutionWidth); \
+ GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+ (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+ (int) kernel->pipeline.threadExecutionWidth); \
*/
-#define GGML_METAL_ADD_KERNEL(name) \
- ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
- ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
- if (error) { \
- GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
- return NULL; \
+#define GGML_METAL_ADD_KERNEL(e, name, supported) \
+ if (supported) { \
+ struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
+ kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \
+ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \
+ GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+ (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+ (int) kernel->pipeline.threadExecutionWidth); \
+ if (error) { \
+ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+ return NULL; \
+ } \
+ } else { \
+ GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \
}
- GGML_METAL_ADD_KERNEL(add);
- GGML_METAL_ADD_KERNEL(add_row);
- GGML_METAL_ADD_KERNEL(mul);
- GGML_METAL_ADD_KERNEL(mul_row);
- GGML_METAL_ADD_KERNEL(div);
- GGML_METAL_ADD_KERNEL(div_row);
- GGML_METAL_ADD_KERNEL(scale);
- GGML_METAL_ADD_KERNEL(scale_4);
- GGML_METAL_ADD_KERNEL(tanh);
- GGML_METAL_ADD_KERNEL(relu);
- GGML_METAL_ADD_KERNEL(gelu);
- GGML_METAL_ADD_KERNEL(gelu_quick);
- GGML_METAL_ADD_KERNEL(silu);
- GGML_METAL_ADD_KERNEL(soft_max);
- GGML_METAL_ADD_KERNEL(soft_max_4);
- GGML_METAL_ADD_KERNEL(diag_mask_inf);
- GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
- GGML_METAL_ADD_KERNEL(get_rows_f32);
- GGML_METAL_ADD_KERNEL(get_rows_f16);
- GGML_METAL_ADD_KERNEL(get_rows_q4_0);
- GGML_METAL_ADD_KERNEL(get_rows_q4_1);
- GGML_METAL_ADD_KERNEL(get_rows_q5_0);
- GGML_METAL_ADD_KERNEL(get_rows_q5_1);
- GGML_METAL_ADD_KERNEL(get_rows_q8_0);
- GGML_METAL_ADD_KERNEL(get_rows_q2_K);
- GGML_METAL_ADD_KERNEL(get_rows_q3_K);
- GGML_METAL_ADD_KERNEL(get_rows_q4_K);
- GGML_METAL_ADD_KERNEL(get_rows_q5_K);
- GGML_METAL_ADD_KERNEL(get_rows_q6_K);
- GGML_METAL_ADD_KERNEL(get_rows_i32);
- GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
- GGML_METAL_ADD_KERNEL(get_rows_iq2_xs);
- GGML_METAL_ADD_KERNEL(rms_norm);
- GGML_METAL_ADD_KERNEL(group_norm);
- GGML_METAL_ADD_KERNEL(norm);
- GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
- GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
- GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_iq2_xs_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
- //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
- GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
- //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
- //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
- GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xs_f32);
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
- GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_iq2_xs_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xs_f32);
- }
- GGML_METAL_ADD_KERNEL(rope_f32);
- GGML_METAL_ADD_KERNEL(rope_f16);
- GGML_METAL_ADD_KERNEL(alibi_f32);
- GGML_METAL_ADD_KERNEL(im2col_f16);
- GGML_METAL_ADD_KERNEL(upscale_f32);
- GGML_METAL_ADD_KERNEL(pad_f32);
- GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
- GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
- GGML_METAL_ADD_KERNEL(leaky_relu_f32);
- GGML_METAL_ADD_KERNEL(cpy_f32_f16);
- GGML_METAL_ADD_KERNEL(cpy_f32_f32);
- GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
- GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
- GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
- //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
- //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
- GGML_METAL_ADD_KERNEL(cpy_f16_f16);
- GGML_METAL_ADD_KERNEL(cpy_f16_f32);
- GGML_METAL_ADD_KERNEL(concat);
- GGML_METAL_ADD_KERNEL(sqr);
- GGML_METAL_ADD_KERNEL(sum_rows);
-
-#undef GGML_METAL_ADD_KERNEL
+ // simd_sum and simd_max require MTLGPUFamilyApple7
+
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
}
return ctx;
void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
-#define GGML_METAL_DEL_KERNEL(name) \
- [ctx->function_##name release]; \
- [ctx->pipeline_##name release];
-
- GGML_METAL_DEL_KERNEL(add);
- GGML_METAL_DEL_KERNEL(add_row);
- GGML_METAL_DEL_KERNEL(mul);
- GGML_METAL_DEL_KERNEL(mul_row);
- GGML_METAL_DEL_KERNEL(div);
- GGML_METAL_DEL_KERNEL(div_row);
- GGML_METAL_DEL_KERNEL(scale);
- GGML_METAL_DEL_KERNEL(scale_4);
- GGML_METAL_DEL_KERNEL(tanh);
- GGML_METAL_DEL_KERNEL(relu);
- GGML_METAL_DEL_KERNEL(gelu);
- GGML_METAL_DEL_KERNEL(gelu_quick);
- GGML_METAL_DEL_KERNEL(silu);
- GGML_METAL_DEL_KERNEL(soft_max);
- GGML_METAL_DEL_KERNEL(soft_max_4);
- GGML_METAL_DEL_KERNEL(diag_mask_inf);
- GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
- GGML_METAL_DEL_KERNEL(get_rows_f32);
- GGML_METAL_DEL_KERNEL(get_rows_f16);
- GGML_METAL_DEL_KERNEL(get_rows_q4_0);
- GGML_METAL_DEL_KERNEL(get_rows_q4_1);
- GGML_METAL_DEL_KERNEL(get_rows_q5_0);
- GGML_METAL_DEL_KERNEL(get_rows_q5_1);
- GGML_METAL_DEL_KERNEL(get_rows_q8_0);
- GGML_METAL_DEL_KERNEL(get_rows_q2_K);
- GGML_METAL_DEL_KERNEL(get_rows_q3_K);
- GGML_METAL_DEL_KERNEL(get_rows_q4_K);
- GGML_METAL_DEL_KERNEL(get_rows_q5_K);
- GGML_METAL_DEL_KERNEL(get_rows_q6_K);
- GGML_METAL_DEL_KERNEL(get_rows_i32);
- GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
- GGML_METAL_DEL_KERNEL(get_rows_iq2_xs);
- GGML_METAL_DEL_KERNEL(rms_norm);
- GGML_METAL_DEL_KERNEL(group_norm);
- GGML_METAL_DEL_KERNEL(norm);
- GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
- GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
- GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_iq2_xs_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
- //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
- GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
- //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
- //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
- GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xs_f32);
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
- GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_iq2_xs_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xs_f32);
- }
- GGML_METAL_DEL_KERNEL(rope_f32);
- GGML_METAL_DEL_KERNEL(rope_f16);
- GGML_METAL_DEL_KERNEL(alibi_f32);
- GGML_METAL_DEL_KERNEL(im2col_f16);
- GGML_METAL_DEL_KERNEL(upscale_f32);
- GGML_METAL_DEL_KERNEL(pad_f32);
- GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
- GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
- GGML_METAL_DEL_KERNEL(leaky_relu_f32);
- GGML_METAL_DEL_KERNEL(cpy_f32_f16);
- GGML_METAL_DEL_KERNEL(cpy_f32_f32);
- GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
- GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
- GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
- //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
- //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
- GGML_METAL_DEL_KERNEL(cpy_f16_f16);
- GGML_METAL_DEL_KERNEL(cpy_f16_f32);
- GGML_METAL_DEL_KERNEL(concat);
- GGML_METAL_DEL_KERNEL(sqr);
- GGML_METAL_DEL_KERNEL(sum_rows);
-
-#undef GGML_METAL_DEL_KERNEL
for (int i = 0; i < ctx->n_buffers; ++i) {
[ctx->buffers[i].metal release];
}
+ // release the compiled kernel objects (replaces the per-name
+ // GGML_METAL_DEL_KERNEL list above); slots that were never populated
+ // (e.g. kernels registered only when ctx->support_simdgroup_mm is set)
+ // hold nil and are skipped by the checks below
+ for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
+ if (ctx->kernels[i].pipeline) {
+ [ctx->kernels[i].pipeline release];
+ }
+
+ if (ctx->kernels[i].function) {
+ [ctx->kernels[i].function release];
+ }
+ }
+
[ctx->library release];
[ctx->queue release];
[ctx->device release];
}
}
-static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
+// NOTE(review): signature gains the context so op support can depend on the
+// capabilities of the selected device instead of being unconditionally true
+static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
switch (op->op) {
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_SUM_ROWS:
+ return true;
case GGML_OP_SOFT_MAX:
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
+ // these ops are only available when the device supports
+ // simdgroup reductions (see ctx->support_simdgroup_reduction)
+ return ctx->support_simdgroup_reduction;
case GGML_OP_NORM:
case GGML_OP_ALIBI:
case GGML_OP_ROPE:
case GGML_OP_PAD:
case GGML_OP_ARGSORT:
case GGML_OP_LEAKY_RELU:
+ return true;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
- return true;
+ // matrix multiplication support is likewise gated on
+ // simdgroup-reduction capability (previously always true)
+ return ctx->support_simdgroup_reduction;
case GGML_OP_CPY:
case GGML_OP_DUP:
case GGML_OP_CONT:
return false;
}
}
+
bool ggml_metal_graph_compute(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
} break;
}
- if (!ggml_metal_supports_op(dst)) {
+ if (!ggml_metal_supports_op(ctx, dst)) {
GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
GGML_ASSERT(!"unsupported op");
}
{
const int64_t nb = ne00;
- [encoder setComputePipelineState:ctx->pipeline_concat];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
nb = ne00 / 4;
switch (dst->op) {
- case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
- case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
- case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
default: GGML_ASSERT(false);
}
bcast_row = true;
} else {
switch (dst->op) {
- case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
- case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
- case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
default: GGML_ASSERT(false);
}
}
// not sure how to avoid this
// TODO: make a simpler cpy_bytes kernel
- const int nth = MIN((int) ctx->pipeline_cpy_f32_f32.maxTotalThreadsPerThreadgroup, ne00);
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
- [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
- [encoder setComputePipelineState:ctx->pipeline_add];
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
- const int nth = MIN((int) ctx->pipeline_add.maxTotalThreadsPerThreadgroup, ne00);
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
int64_t n = ggml_nelements(dst);
+ id<MTLComputePipelineState> pipeline = nil;
+
if (n % 4 == 0) {
n /= 4;
- [encoder setComputePipelineState:ctx->pipeline_scale_4];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
} else {
- [encoder setComputePipelineState:ctx->pipeline_scale];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
}
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_TANH:
{
- [encoder setComputePipelineState:ctx->pipeline_tanh];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
} break;
case GGML_UNARY_OP_RELU:
{
- [encoder setComputePipelineState:ctx->pipeline_relu];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
} break;
case GGML_UNARY_OP_GELU:
{
- [encoder setComputePipelineState:ctx->pipeline_gelu];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
} break;
case GGML_UNARY_OP_GELU_QUICK:
{
- [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
} break;
case GGML_UNARY_OP_SILU:
{
- [encoder setComputePipelineState:ctx->pipeline_silu];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
{
GGML_ASSERT(ggml_is_contiguous(src0));
- [encoder setComputePipelineState:ctx->pipeline_sqr];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SUM_ROWS:
{
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
- [encoder setComputePipelineState:ctx->pipeline_sum_rows];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
{
int nth = 32; // SIMD width
+ id<MTLComputePipelineState> pipeline = nil;
+
if (ne00%4 == 0) {
while (nth < ne00/4 && nth < 256) {
nth *= 2;
}
- [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
} else {
while (nth < ne00 && nth < 1024) {
nth *= 2;
}
- [encoder setComputePipelineState:ctx->pipeline_soft_max];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
}
const float scale = ((float *) dst->op_params)[0];
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) {
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
{
const int n_past = ((int32_t *)(dst->op_params))[0];
+ id<MTLComputePipelineState> pipeline = nil;
+
if (ne00%8 == 0) {
- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
} else {
- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
}
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
ne00 % 32 == 0 && ne00 >= 64 &&
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
switch (src0->type) {
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
- case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
- case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
- case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
- case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xs_f32]; break;
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
}
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
int nrows = 1;
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+ id<MTLComputePipelineState> pipeline = nil;
+
// use custom matrix x vector kernel
switch (src0t) {
case GGML_TYPE_F32:
{
GGML_ASSERT(src1t == GGML_TYPE_F32);
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
nrows = 4;
} break;
case GGML_TYPE_F16:
nth1 = 1;
if (src1t == GGML_TYPE_F32) {
if (ne11 * ne12 < 4) {
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
nrows = ne11;
} else {
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
nrows = 4;
}
} else {
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
nrows = 4;
}
} break;
{
nth0 = 8;
nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
} break;
case GGML_TYPE_Q4_1:
{
nth0 = 8;
nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
} break;
case GGML_TYPE_Q5_0:
{
nth0 = 8;
nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
} break;
case GGML_TYPE_Q5_1:
{
nth0 = 8;
nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
} break;
case GGML_TYPE_Q8_0:
{
nth0 = 8;
nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
} break;
case GGML_TYPE_Q2_K:
{
nth0 = 2;
nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
} break;
case GGML_TYPE_Q3_K:
{
nth0 = 2;
nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
} break;
case GGML_TYPE_Q4_K:
{
nth0 = 4; //1;
nth1 = 8; //32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
} break;
case GGML_TYPE_Q5_K:
{
nth0 = 2;
nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
} break;
case GGML_TYPE_Q6_K:
{
nth0 = 2;
nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
} break;
case GGML_TYPE_IQ2_XXS:
{
nth0 = 4;
nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
} break;
case GGML_TYPE_IQ2_XS:
{
nth0 = 4;
nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xs_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
} break;
default:
{
GGML_ASSERT(ne00 >= nth0*nth1);
}
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne20 % 32 == 0 && ne20 >= 64 &&
ne11 > ne11_mm_min) {
+
+ id<MTLComputePipelineState> pipeline = nil;
+
switch (src2->type) {
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
- case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
- case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
- case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
- case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xs_f32]; break;
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
}
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
int nrows = 1;
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+ id<MTLComputePipelineState> pipeline = nil;
+
// use custom matrix x vector kernel
switch (src2t) {
case GGML_TYPE_F32:
{
GGML_ASSERT(src1t == GGML_TYPE_F32);
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
} break;
case GGML_TYPE_F16:
{
GGML_ASSERT(src1t == GGML_TYPE_F32);
nth0 = 32;
nth1 = 1;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
} break;
case GGML_TYPE_Q4_0:
{
nth0 = 8;
nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
} break;
case GGML_TYPE_Q4_1:
{
nth0 = 8;
nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
} break;
case GGML_TYPE_Q5_0:
{
nth0 = 8;
nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
} break;
case GGML_TYPE_Q5_1:
{
nth0 = 8;
nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
} break;
case GGML_TYPE_Q8_0:
{
nth0 = 8;
nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
} break;
case GGML_TYPE_Q2_K:
{
nth0 = 2;
nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
} break;
case GGML_TYPE_Q3_K:
{
nth0 = 2;
nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
} break;
case GGML_TYPE_Q4_K:
{
nth0 = 4; //1;
nth1 = 8; //32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
} break;
case GGML_TYPE_Q5_K:
{
nth0 = 2;
nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
} break;
case GGML_TYPE_Q6_K:
{
nth0 = 2;
nth1 = 32;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
} break;
case GGML_TYPE_IQ2_XXS:
{
nth0 = 4;
nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
} break;
case GGML_TYPE_IQ2_XS:
{
nth0 = 4;
nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xs_f32];
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
} break;
default:
{
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
} break;
case GGML_OP_GET_ROWS:
{
+ id<MTLComputePipelineState> pipeline = nil;
+
switch (src0->type) {
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
- case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
- case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
- case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
- case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
- case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xs]; break;
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
+ case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
}
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
nth *= 2;
}
- [encoder setComputePipelineState:ctx->pipeline_rms_norm];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
// nth *= 2;
//}
- [encoder setComputePipelineState:ctx->pipeline_group_norm];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
const int nth = MIN(256, ne00);
- [encoder setComputePipelineState:ctx->pipeline_norm];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
- [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+ id<MTLComputePipelineState> pipeline = nil;
+
switch (src0->type) {
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
default: GGML_ASSERT(false);
};
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+ id<MTLComputePipelineState> pipeline = nil;
+
switch (src0->type) {
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
default: GGML_ASSERT(false);
};
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
const int sf = dst->op_params[0];
- [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
- const int nth = MIN((int) ctx->pipeline_upscale_f32.maxTotalThreadsPerThreadgroup, ne0);
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
- [encoder setComputePipelineState:ctx->pipeline_pad_f32];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+ id<MTLComputePipelineState> pipeline = nil;
+
switch (order) {
- case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
- case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
+ case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
+ case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
default: GGML_ASSERT(false);
};
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
float slope;
memcpy(&slope, dst->op_params, sizeof(float));
- [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+ id<MTLComputePipelineState> pipeline = nil;
+
switch (src0t) {
case GGML_TYPE_F32:
{
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
switch (dstt) {
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
- case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
- //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
- //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
+ //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
+ //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;
case GGML_TYPE_F16:
{
switch (dstt) {
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;
default: GGML_ASSERT(false && "not implemented");
}
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
}
static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
- return ggml_metal_supports_op(op);
- UNUSED(backend);
+ struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
+ return ggml_metal_supports_op(metal_ctx, op);
}
static struct ggml_backend_i ggml_backend_metal_i = {