#define UNUSED(x) (void)(x)
+// globals
+
+// stand-in for MTLGPUFamilyMetal3, which is not available in some environments
+static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
+
+// initialized in ggml_backend_metal_reg
+static struct ggml_backend_reg g_ggml_backend_metal_reg;
+static struct ggml_backend_device g_ggml_backend_metal_device;
+
+// information about a Metal device
+// note: assumes single GPU device - the default one
+// TODO: support multiple GPU devices
+static struct ggml_backend_metal_device_context {
+ id<MTLDevice> mtl_device;
+ int mtl_device_ref_count;
+
+ bool support_simdgroup_reduction;
+ bool support_simdgroup_mm;
+
+ char name[128];
+} g_ggml_ctx_dev_main = {
+ /*.mtl_device =*/ nil,
+ /*.mtl_device_ref_count =*/ 0,
+ /*.support_simdgroup_reduction =*/ false,
+ /*.support_simdgroup_mm =*/ false,
+ /*.name =*/ "",
+};
+
+// acquire
+static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
+ assert(ctx != NULL);
+
+ if (ctx->mtl_device == nil) {
+ ctx->mtl_device = MTLCreateSystemDefaultDevice();
+
+ ctx->support_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+ ctx->support_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
+
+ ctx->support_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+
+ strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
+ }
+
+ ctx->mtl_device_ref_count++;
+
+ return ctx->mtl_device;
+}
+
+// release
+static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) {
+ assert(ctx != NULL);
+ assert(ctx->mtl_device_ref_count > 0);
+
+ ctx->mtl_device_ref_count--;
+
+ if (ctx->mtl_device_ref_count == 0) {
+ [ctx->mtl_device release];
+ ctx->mtl_device = nil;
+ }
+}
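+
+// usage sketch (illustrative): every acquire must be balanced by a release -
+// the MTLDevice is created lazily on the first acquire and freed when the
+// refcount drops back to zero:
+//
+//   id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
+//   ... use device ...
+//   ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);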
+
+// kernels
+
struct ggml_metal_kernel {
id<MTLComputePipelineState> pipeline;
};
struct ggml_backend_metal_context {
- id<MTLDevice> device;
id<MTLCommandQueue> queue;
dispatch_queue_t d_queue;
struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
- bool support_simdgroup_reduction;
- bool support_simdgroup_mm;
-
// capture state
bool capture_next_compute;
bool capture_started;
return data;
}
-static struct ggml_backend_metal_context * ggml_metal_init(void) {
+static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
GGML_LOG_INFO("%s: allocating\n", __func__);
#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
[devices release]; // since it was created by a *Copy* C method
#endif
- // Pick and show default Metal device
- id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+ // init context
+ struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
+ struct ggml_backend_metal_device_context * ctx_dev = dev->context;
+
+ id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
- // Configure context
- struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
- ctx->device = device;
- ctx->queue = [ctx->device newCommandQueue];
+ ctx->queue = [device newCommandQueue];
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
id<MTLLibrary> metal_library;
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
- metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
+ metal_library = [device newLibraryWithURL:libURL error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
//[options setFastMathEnabled:false];
- metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
+ metal_library = [device newLibraryWithSource:src options:options error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
}
// print MTL GPU family:
- GGML_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
-
- const NSInteger MTLGPUFamilyMetal3 = 5001;
+ GGML_LOG_INFO("%s: GPU name: %s\n", __func__, [[device name] UTF8String]);
// determine max supported GPU family
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
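// note: the loops below probe downward from the highest known raw family value
// and log the first family that the device reports as supported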
{
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
- if ([ctx->device supportsFamily:i]) {
+ if ([device supportsFamily:i]) {
GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
break;
}
}
for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
- if ([ctx->device supportsFamily:i]) {
+ if ([device supportsFamily:i]) {
GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
break;
}
}
- for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
- if ([ctx->device supportsFamily:i]) {
- GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
+ for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) {
+ if ([device supportsFamily:i]) {
+ GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i);
break;
}
}
}
- ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7];
- ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
-
- ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
-
- GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
- GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
- GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
+ GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
+ GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
ctx->capture_next_compute = false;
ctx->capture_started = false;
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
if (@available(macOS 10.12, iOS 16.0, *)) {
- GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
- }
-#elif TARGET_OS_OSX
- if (ctx->device.maxTransferRate != 0) {
- GGML_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
- } else {
- GGML_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
+ GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
}
#endif
if (supported) { \
struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
- kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
+ kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
[metal_function release]; \
if (error) { \
GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
}
+ const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
+ const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
+
// simd_sum and simd_max require MTLGPUFamilyApple7
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
Block_release(ctx->encode_async);
[ctx->queue release];
- [ctx->device release];
dispatch_release(ctx->d_queue);
return nil;
}
-static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
+static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
for (size_t i = 0, n = 3; i < n; ++i) {
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
return false;
}
}
+ const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
+ const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
+
switch (op->op) {
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_OP_SOFT_MAX:
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
- return ctx->support_simdgroup_reduction;
+ return support_simdgroup_reduction;
case GGML_OP_NORM:
case GGML_OP_ROPE:
return true;
if (op->src[0]->ne[0] == 256) {
return false;
}
- return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
+ return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
return true;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
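// i.e. when src0 is F32, src1 must also be F32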
- return ctx->support_simdgroup_reduction &&
+ return support_simdgroup_reduction &&
(op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
case GGML_OP_CPY:
case GGML_OP_DUP:
}
static void ggml_metal_encode_node(
- struct ggml_backend_metal_context * ctx,
+ ggml_backend_t backend,
int idx,
id<MTLComputeCommandEncoder> encoder) {
+ struct ggml_backend_metal_context * ctx = backend->context;
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
struct ggml_cgraph * gf = ctx->gf;
struct ggml_tensor * node = ggml_graph_node(gf, idx);
} break;
}
- if (!ggml_metal_supports_op(ctx, dst)) {
+ if (!ggml_metal_supports_op(ctx_dev, dst)) {
GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
GGML_ABORT("unsupported op");
}
// dst->name);
//}
+ id<MTLDevice> device = ctx_dev->mtl_device;
+
switch (dst->op) {
case GGML_OP_CONCAT:
{
// the numbers below are measured on M2 Ultra for 7B and 13B models
// these numbers do not translate to other devices or model sizes
// TODO: need to find a better approach
- if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+ if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
switch (src0t) {
case GGML_TYPE_F16: ne11_mm_min = 2; break;
case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPUs and older A-chips will reuse the matrix-vector multiplication kernel
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ if ([device supportsFamily:MTLGPUFamilyApple7] &&
!ggml_is_transposed(src0) &&
!ggml_is_transposed(src1) &&
src1t == GGML_TYPE_F32 &&
// ne21 = n_rows
const int dst_rows = ne20*ne21;
const int dst_rows_min = n_as;
- const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
+ const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4;
// max size of the rowids array in the kernel shared buffer
GGML_ASSERT(dst_rows <= dst_rows_max);
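// e.g. assuming 32 KB of threadgroup memory (typical for Apple silicon GPUs):
//   dst_rows_max = (32768 - 32 - 8192)/4 = 6136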
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
// indirect matrix multiplication
// !!!
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ if ([device supportsFamily:MTLGPUFamilyApple7] &&
ne00 % 32 == 0 && ne00 >= 64 &&
dst_rows > dst_rows_min) {
while (true) {
const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
- if (smem > ctx->device.maxThreadgroupMemoryLength) {
+ if (smem > device.maxThreadgroupMemoryLength) {
break;
}
nsgmax *= 2;
const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
- //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
- GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+ //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
- //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
- GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+ //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
}
static enum ggml_status ggml_metal_graph_compute(
- struct ggml_backend_metal_context * ctx,
- struct ggml_cgraph * gf) {
+ ggml_backend_t backend,
+ struct ggml_cgraph * gf) {
+ struct ggml_backend_metal_context * ctx = backend->context;
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
// number of nodes encoded by the main thread (empirically determined)
const int n_main = 128;
if (!ctx->capture_started) {
// create capture scope
- ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
+ ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
descriptor.captureObject = ctx->capture_scope;
// backend interface
-// default buffer
-static id<MTLDevice> g_backend_device = nil;
-static int g_backend_device_ref_count = 0;
-
-static id<MTLDevice> ggml_backend_metal_get_device(void) {
- if (g_backend_device == nil) {
- g_backend_device = MTLCreateSystemDefaultDevice();
- }
-
- g_backend_device_ref_count++;
-
- return g_backend_device;
-}
-
-static void ggml_backend_metal_free_device(void) {
- assert(g_backend_device_ref_count > 0);
-
- g_backend_device_ref_count--;
-
- if (g_backend_device_ref_count == 0) {
- [g_backend_device release];
- g_backend_device = nil;
- }
-}
-
static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
return "Metal";
for (int i = 0; i < ctx->n_buffers; i++) {
[ctx->buffers[i].metal release];
}
- ggml_backend_metal_free_device();
+ ggml_backend_metal_device_rel(buffer->buft->device->context);
if (ctx->owned) {
#if TARGET_OS_OSX
size_aligned += (size_page - (size_aligned % size_page));
}
- id<MTLDevice> device = ggml_backend_metal_get_device();
+ id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
ctx->all_data = ggml_metal_host_malloc(size_aligned);
ctx->all_size = size_aligned;
if (size_aligned > 0) {
ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
- length:size_aligned
- options:MTLResourceStorageModeShared
- deallocator:nil];
+ length:size_aligned
+ options:MTLResourceStorageModeShared
+ deallocator:nil];
}
}
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
free(ctx);
- ggml_backend_metal_free_device();
+ ggml_backend_metal_device_rel(buft->device->context);
return NULL;
}
}
static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
- id<MTLDevice> device = ggml_backend_metal_get_device();
- size_t max_size = device.maxBufferLength;
- ggml_backend_metal_free_device();
+ id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
+ const size_t max_size = device.maxBufferLength;
+ ggml_backend_metal_device_rel(buft->device->context);
return max_size;
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .is_host = */ ggml_backend_metal_buffer_type_is_host,
},
- /* .device = */ NULL,
+ /* .device = */ &g_ggml_backend_metal_device,
/* .context = */ NULL,
};
return &ggml_backend_buffer_type_metal;
}
-// buffer from ptr
-
+// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr
ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
size_aligned += (size_page - (size_aligned % size_page));
}
- id<MTLDevice> device = ggml_backend_metal_get_device();
+ id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
// the buffer fits into the max buffer size allowed by the device
if (size_aligned <= device.maxBufferLength) {
}
static void ggml_backend_metal_free(ggml_backend_t backend) {
- struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+ struct ggml_backend_metal_context * ctx = backend->context;
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+ ggml_backend_metal_device_rel(ctx_dev);
ggml_metal_free(ctx);
+
free(backend);
}
}
static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
- struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
-
- return ggml_metal_graph_compute(metal_ctx, cgraph);
-}
-
-static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
- struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
-
- return ggml_metal_supports_op(metal_ctx, op);
-}
-
-static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
- return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
-
- UNUSED(backend);
+ return ggml_metal_graph_compute(backend, cgraph);
}
static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
}
- ggml_metal_encode_node(ctx, idx, encoder);
+ ggml_metal_encode_node(backend, idx, encoder);
if (should_capture) {
[encoder popDebugGroup];
/* .graph_plan_update = */ NULL,
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_metal_graph_compute,
- /* .supports_op = */ ggml_backend_metal_supports_op,
- /* .supports_buft = */ ggml_backend_metal_supports_buft,
+ /* .supports_op = */ NULL,
+ /* .supports_buft = */ NULL,
/* .offload_op = */ NULL,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
return &guid;
}
+// TODO: remove in the future
ggml_backend_t ggml_backend_metal_init(void) {
- struct ggml_backend_metal_context * ctx = ggml_metal_init();
+ ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
+
+ struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
if (ctx == NULL) {
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
return NULL;
*backend = (struct ggml_backend) {
/* .guid = */ ggml_backend_metal_guid(),
/* .interface = */ ggml_backend_metal_i,
- /* .device = */ NULL,
+ /* .device = */ dev,
/* .context = */ ctx,
};
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
GGML_ASSERT(ggml_backend_is_metal(backend));
- struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
- return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+ return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
}
void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
ctx->capture_next_compute = true;
}
-ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+// backend device
+
+static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
+ return "Metal";
-ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
- return ggml_backend_metal_init();
+ GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
+ // acq/rel just to populate ctx_dev->name in case it hasn't been done yet
+ struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
+ ggml_backend_metal_device_acq(ctx_dev);
+ ggml_backend_metal_device_rel(ctx_dev);
+
+ return ctx_dev->name;
+}
+
+static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+ if (@available(macOS 10.12, iOS 16.0, *)) {
+ struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
+ id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+ *total = device.recommendedMaxWorkingSetSize;
+ *free = *total - device.currentAllocatedSize;
+
+ ggml_backend_metal_device_rel(ctx_dev);
+ } else {
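+ // no working-set API on this OS version - report placeholder sizes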
+ *free = 1;
+ *total = 1;
+ }
+}
+
+static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {
+ return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+
+ GGML_UNUSED(dev);
+}
+
+static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+ props->name = ggml_backend_metal_device_get_name(dev);
+ props->description = ggml_backend_metal_device_get_description(dev);
+ props->type = ggml_backend_metal_device_get_type(dev);
+ ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
+ props->caps = (struct ggml_backend_dev_caps) {
+ /* .async = */ false,
+ /* .host_buffer = */ false,
+ /* .buffer_from_host_ptr = */ true,
+ /* .events = */ false,
+ };
+}
+
+static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
+ struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
+ if (ctx == NULL) {
+ GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
+ return NULL;
+ }
+
+ ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
+
+ *backend = (struct ggml_backend) {
+ /* .guid = */ ggml_backend_metal_guid(),
+ /* .interface = */ ggml_backend_metal_i,
+ /* .device = */ dev,
+ /* .context = */ ctx,
+ };
+
+ ggml_backend_metal_set_n_cb(backend, 1);
+
+ return backend;
GGML_UNUSED(params);
- GGML_UNUSED(user_data);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {
+ return ggml_backend_metal_buffer_type();
+
+ GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+ struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
+
+ ctx->all_data = ptr;
+ ctx->all_size = size;
+ ctx->owned = false;
+ ctx->n_buffers = 0;
+
+ const size_t size_page = sysconf(_SC_PAGESIZE);
+
+ // page-align the data ptr
+ {
+ const uintptr_t offs = (uintptr_t) ptr % size_page;
+ ptr = (void *) ((char *) ptr - offs);
+ size += offs;
+ }
+
+ size_t size_aligned = size;
+ if ((size_aligned % size_page) != 0) {
+ size_aligned += (size_page - (size_aligned % size_page));
+ }
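+
+ // example (hypothetical values): with a 16 KiB page size and a ptr that is
+ // 100 bytes past a page boundary, offs = 100, so ptr is moved back to the
+ // boundary and size grows by 100; size_aligned then rounds up to the next
+ // 16 KiB multiple, as newBufferWithBytesNoCopy requires page-aligned
+ // pointers and lengths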
+
+ struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
+ id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+ // the buffer fits into the max buffer size allowed by the device
+ if (size_aligned <= device.maxBufferLength) {
+ ctx->buffers[ctx->n_buffers].data = ptr;
+ ctx->buffers[ctx->n_buffers].size = size;
+ ctx->buffers[ctx->n_buffers].metal = nil;
+
+ if (size_aligned > 0) {
+ ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
+ GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
+ return NULL;
+ }
+ }
+
+ ggml_backend_metal_log_allocated_size(device, size_aligned);
+
+ ++ctx->n_buffers;
+ } else {
+ // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
+ // one of the views
+ const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round up to a whole page, plus one spare page just in case
+ const size_t size_step = device.maxBufferLength - size_ovlp;
+ const size_t size_view = device.maxBufferLength;
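+
+ // worked example (hypothetical numbers): with maxBufferLength = 1000 pages
+ // and max_tensor_size = 10 pages, size_ovlp = 11 pages (10 rounded up, plus
+ // one spare page), so 1000-page views advance in steps of 989 pages and
+ // consecutive views overlap by 11 pages - enough for the largest tensor to
+ // fit entirely inside at least one view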
+
+ for (size_t i = 0; i < size; i += size_step) {
+ const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
+
+ ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) ptr + i);
+ ctx->buffers[ctx->n_buffers].size = size_step_aligned;
+ ctx->buffers[ctx->n_buffers].metal = nil;
+
+ if (size_step_aligned > 0) {
+ ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
+ GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
+ return NULL;
+ }
+ }
+
+ ggml_backend_metal_log_allocated_size(device, size_step_aligned);
+
+ if (i + size_step < size) {
+ GGML_LOG_INFO("\n");
+ }
+
+ ++ctx->n_buffers;
+ }
+ }
+
+ return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
+}
+
+static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+ struct ggml_backend_metal_device_context * ctx_dev = dev->context;
+
+ return ggml_metal_supports_op(ctx_dev, op);
+}
+
+static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
+
+ UNUSED(dev);
+}
+
+static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+ return false;
+
+ GGML_UNUSED(dev);
+ GGML_UNUSED(op);
+}
+
+static struct ggml_backend_device_i ggml_backend_metal_device_i = {
+ /* .get_name = */ ggml_backend_metal_device_get_name,
+ /* .get_description = */ ggml_backend_metal_device_get_description,
+ /* .get_memory = */ ggml_backend_metal_device_get_memory,
+ /* .get_type = */ ggml_backend_metal_device_get_type,
+ /* .get_props = */ ggml_backend_metal_device_get_props,
+ /* .init_backend = */ ggml_backend_metal_device_init,
+ /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type,
+ /* .get_host_buffer_type = */ NULL,
+ /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr,
+ /* .supports_op = */ ggml_backend_metal_device_supports_op,
+ /* .supports_buft = */ ggml_backend_metal_device_supports_buft,
+ /* .offload_op = */ ggml_backend_metal_device_offload_op,
+ /* .event_new = */ NULL,
+ /* .event_free = */ NULL,
+ /* .event_synchronize = */ NULL,
+};
+
+// backend registry
+
+static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
+ return "Metal";
+
+ GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
+ return 1;
+
+ GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
+ GGML_ASSERT(index == 0);
+
+ return &g_ggml_backend_metal_device;
+
+ GGML_UNUSED(reg);
+ GGML_UNUSED(index);
+}
+
+static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
+ /* .get_name = */ ggml_backend_metal_reg_get_name,
+ /* .device_count = */ ggml_backend_metal_reg_device_count,
+ /* .device_get = */ ggml_backend_metal_reg_device_get,
+ /* .get_proc_address = */ NULL,
+};
+
+ggml_backend_reg_t ggml_backend_metal_reg(void) {
+ // TODO: make this thread-safe somehow?
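+ // one option (not implemented here) would be a dispatch_once guard, e.g.:
+ //   static dispatch_once_t once;
+ //   dispatch_once(&once, ^{ ...one-time init... });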
+ {
+ g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
+ /* .iface = */ ggml_backend_metal_reg_i,
+ /* .context = */ NULL,
+ };
+
+ g_ggml_backend_metal_device = (struct ggml_backend_device) {
+ /* .iface = */ ggml_backend_metal_device_i,
+ /* .reg = */ &g_ggml_backend_metal_reg,
+ /* .context = */ &g_ggml_ctx_dev_main,
+ };
+ }
+
+ return &g_ggml_backend_metal_reg;
}