]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : add metal backend registry / device (llama/9713)
authorGeorgi Gerganov <redacted>
Mon, 7 Oct 2024 15:27:51 +0000 (18:27 +0300)
committerGeorgi Gerganov <redacted>
Wed, 16 Oct 2024 08:28:38 +0000 (11:28 +0300)
* ggml : add metal backend registry / device

ggml-ci

* metal : fix names [no ci]

* metal : global registry and device instances

ggml-ci

* cont : alternative initialization of global objects

ggml-ci

* llama : adapt to backend changes

ggml-ci

* fixes

* metal : fix indent

* metal : fix build when MTLGPUFamilyApple3 is not available

ggml-ci

* fix merge

* metal : avoid unnecessary singleton accesses

ggml-ci

* metal : minor fix [no ci]

* metal : g_state -> g_ggml_ctx_dev_main [no ci]

* metal : avoid reference of device context in the backend context

ggml-ci

* metal : minor [no ci]

* metal : fix maxTransferRate check

* metal : remove transfer rate stuff

---------

Co-authored-by: slaren <redacted>
include/ggml-backend.h
include/ggml-metal.h
src/ggml-backend.cpp
src/ggml-cuda.cu
src/ggml-metal.m

index 4d7d2716e7a774bdd8a8c738871183c4fe36cc66..152b9adb08fa2c74940de822a4b47ac497421bb1 100644 (file)
@@ -127,6 +127,8 @@ extern "C" {
         bool async;
         // pinned host buffer
         bool host_buffer;
+        // creating buffers from host ptr
+        bool buffer_from_host_ptr;
         // event synchronization
         bool events;
     };
index c3ec572b2dc832de51b4fda9ec1c9025d1dacdb3..b8d3f678b71576eee3a085cbf0deffc76ddfc11a 100644 (file)
@@ -43,7 +43,9 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);
 
 GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
 
-GGML_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);
+GGML_DEPRECATED(
+        GGML_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
+        "obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713");
 
 GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
 
@@ -57,6 +59,8 @@ GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int fam
 // capture all command buffers committed the next time `ggml_backend_graph_compute` is called
 GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
 
+GGML_API ggml_backend_reg_t ggml_backend_metal_reg(void);
+
 #ifdef __cplusplus
 }
 #endif
index 0551764fe3fb08515bb076f1ed6fc879c92eec29..4f3e9374cadfb7a7a68524c65bab9ee486d6e55f 100644 (file)
@@ -463,6 +463,7 @@ enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) {
 }
 
 void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) {
+    memset(props, 0, sizeof(*props));
     device->iface.get_props(device, props);
 }
 
@@ -479,6 +480,10 @@ ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t devic
 }
 
 ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) {
+    if (device->iface.get_host_buffer_type == NULL) {
+        return NULL;
+    }
+
     return device->iface.get_host_buffer_type(device);
 }
 
@@ -525,6 +530,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
 #include "ggml-cuda.h"
 #endif
 
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
 struct ggml_backend_registry {
     std::vector<ggml_backend_reg_t> backends;
     std::vector<ggml_backend_dev_t> devices;
@@ -533,10 +542,13 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_CUDA
         register_backend(ggml_backend_cuda_reg());
 #endif
+#ifdef GGML_USE_METAL
+        register_backend(ggml_backend_metal_reg());
+#endif
 
         register_backend(ggml_backend_cpu_reg());
 
-        // TODO: sycl, metal, vulkan, kompute, cann
+        // TODO: sycl, vulkan, kompute, cann
     }
 
     void register_backend(ggml_backend_reg_t reg) {
@@ -1118,9 +1130,10 @@ static void ggml_backend_cpu_device_get_props(ggml_backend_dev_t dev, struct ggm
     props->type        = ggml_backend_cpu_device_get_type(dev);
     ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
     props->caps = {
-        /* async       */ false,
-        /* host_buffer */ false,
-        /* events      */ false,
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ true,
+        /* .events                = */ false,
     };
 }
 
index 5b6f605b0089bed5df81919187aee20fcf6180b1..edb61abdfeacfcd6651ba6be6aff22dcb1decbb5 100644 (file)
@@ -2920,9 +2920,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
 #endif
 
     props->caps = {
-        /* async       */ true,
-        /* host_buffer */ host_buffer,
-        /* events      */ events,
+        /* .async                 = */ true,
+        /* .host_buffer           = */ host_buffer,
+        /* .buffer_from_host_ptr  = */ false,
+        /* .events                = */ events,
     };
 }
 
index 08598c28b1e4f444bd9ddbdaa8e3436a831fa502..172a0f925d316968d695160cc5dfede12374dd98 100644 (file)
 
 #define UNUSED(x) (void)(x)
 
+// globals
+
+// overload of MTLGPUFamilyMetal3 (not available in some environments)
+static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
+
+// initialized in ggml_backend_metal_reg
+static struct ggml_backend_reg    g_ggml_backend_metal_reg;
+static struct ggml_backend_device g_ggml_backend_metal_device;
+
+// information about a Metal device
+// note: assumes single GPU device - the default one
+// TODO: support multiple GPU devices
+static struct ggml_backend_metal_device_context {
+    id<MTLDevice> mtl_device;
+    int           mtl_device_ref_count;
+
+    bool support_simdgroup_reduction;
+    bool support_simdgroup_mm;
+
+    char name[128];
+} g_ggml_ctx_dev_main = {
+    /*.mtl_device                  =*/ nil,
+    /*.mtl_device_ref_count        =*/ 0,
+    /*.support_simdgroup_reduction =*/ false,
+    /*.support_simdgroup_mm        =*/ false,
+    /*.name                        =*/ "",
+};
+
+// acquire
+static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
+    assert(ctx != NULL);
+
+    if (ctx->mtl_device == nil) {
+        ctx->mtl_device = MTLCreateSystemDefaultDevice();
+
+        ctx->support_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+        ctx->support_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
+
+        ctx->support_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+
+        strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
+    }
+
+    ctx->mtl_device_ref_count++;
+
+    return ctx->mtl_device;
+}
+
+// release
+static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) {
+    assert(ctx != NULL);
+    assert(ctx->mtl_device_ref_count > 0);
+
+    ctx->mtl_device_ref_count--;
+
+    if (ctx->mtl_device_ref_count == 0) {
+        [ctx->mtl_device release];
+        ctx->mtl_device = nil;
+    }
+}
+
+// kernels
+
 struct ggml_metal_kernel {
     id<MTLComputePipelineState> pipeline;
 };
@@ -214,16 +277,12 @@ enum ggml_metal_kernel_type {
 };
 
 struct ggml_backend_metal_context {
-    id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
 
     dispatch_queue_t d_queue;
 
     struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
 
-    bool support_simdgroup_reduction;
-    bool support_simdgroup_mm;
-
     // capture state
     bool capture_next_compute;
     bool capture_started;
@@ -280,7 +339,7 @@ static void * ggml_metal_host_malloc(size_t n) {
     return data;
 }
 
-static struct ggml_backend_metal_context * ggml_metal_init(void) {
+static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
     GGML_LOG_INFO("%s: allocating\n", __func__);
 
 #if TARGET_OS_OSX && !GGML_METAL_NDEBUG
@@ -292,14 +351,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
     [devices release]; // since it was created by a *Copy* C method
 #endif
 
-    // Pick and show default Metal device
-    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+    // init context
+    struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
+    struct ggml_backend_metal_device_context * ctx_dev = dev->context;
+
+    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
     GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
 
-    // Configure context
-    struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
-    ctx->device = device;
-    ctx->queue  = [ctx->device newCommandQueue];
+    ctx->queue  = [device newCommandQueue];
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
     id<MTLLibrary> metal_library;
@@ -332,7 +391,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
             NSURL * libURL = [NSURL fileURLWithPath:path_lib];
             GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
 
-            metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
+            metal_library = [device newLibraryWithURL:libURL error:&error];
             if (error) {
                 GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
                 return NULL;
@@ -382,7 +441,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
 
                 //[options setFastMathEnabled:false];
 
-                metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
+                metal_library = [device newLibraryWithSource:src options:options error:&error];
                 if (error) {
                     GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
                     return NULL;
@@ -392,44 +451,37 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
     }
 
     // print MTL GPU family:
-    GGML_LOG_INFO("%s: GPU name:   %s\n", __func__, [[ctx->device name] UTF8String]);
-
-    const NSInteger MTLGPUFamilyMetal3 = 5001;
+    GGML_LOG_INFO("%s: GPU name:   %s\n", __func__, [[device name] UTF8String]);
 
     // determine max supported GPU family
     // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
     // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
     {
         for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
-            if ([ctx->device supportsFamily:i]) {
+            if ([device supportsFamily:i]) {
                 GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d  (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
                 break;
             }
         }
 
         for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
-            if ([ctx->device supportsFamily:i]) {
+            if ([device supportsFamily:i]) {
                 GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
                 break;
             }
         }
 
-        for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
-            if ([ctx->device supportsFamily:i]) {
-                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d  (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
+        for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) {
+            if ([device supportsFamily:i]) {
+                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d  (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i);
                 break;
             }
         }
     }
 
-    ctx->support_simdgroup_reduction  = [ctx->device supportsFamily:MTLGPUFamilyApple7];
-    ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
-
-    ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
-
-    GGML_LOG_INFO("%s: simdgroup reduction support   = %s\n",       __func__, ctx->support_simdgroup_reduction ? "true" : "false");
-    GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n",       __func__, ctx->support_simdgroup_mm ? "true" : "false");
-    GGML_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    GGML_LOG_INFO("%s: simdgroup reduction support   = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
+    GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
+    GGML_LOG_INFO("%s: hasUnifiedMemory              = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
 
     ctx->capture_next_compute = false;
     ctx->capture_started = false;
@@ -443,13 +495,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
 
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
-        GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
-    }
-#elif TARGET_OS_OSX
-    if (ctx->device.maxTransferRate != 0) {
-        GGML_LOG_INFO("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
-    } else {
-        GGML_LOG_INFO("%s: maxTransferRate               = built-in GPU\n", __func__);
+        GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
     }
 #endif
 
@@ -470,7 +516,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
         if (supported) { \
             struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
             id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
-            kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
+            kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
             [metal_function release]; \
             if (error) { \
                 GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
@@ -481,6 +527,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
             GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
         }
 
+        const bool support_simdgroup_mm        = ctx_dev->support_simdgroup_mm;
+        const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
+
         // simd_sum and simd_max requires MTLGPUFamilyApple7
 
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                           add,                            true);
@@ -507,10 +556,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,                  gelu_quick_4,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                          silu,                           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                        silu_4,                         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,                  soft_max_f16,                   ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,                soft_max_f16_4,                 ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,                  soft_max_f32,                   ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,                soft_max_f32_4,                 ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,                  soft_max_f16,                   support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,                soft_max_f16_4,                 support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,                  soft_max_f32,                   support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,                soft_max_f32_4,                 support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,                 diag_mask_inf,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,               diag_mask_inf_8,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,                  get_rows_f32,                   true);
@@ -535,101 +584,101 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,               get_rows_iq4_nl,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,               get_rows_iq4_xs,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,                  get_rows_i32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                      rms_norm,                       ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                    group_norm,                     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                      rms_norm,                       support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                    group_norm,                     support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                          norm,                           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                  ssm_conv_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                  ssm_scan_f32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                mul_mv_f16_f16,                 ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                mul_mv_f16_f32,                 ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,           mul_mv_f16_f32_1row,            ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,             mul_mv_f16_f32_l4,              ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,               mul_mv_q4_0_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,               mul_mv_q4_1_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,               mul_mv_q5_0_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,               mul_mv_q5_1_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,               mul_mv_q8_0_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,               mul_mv_q2_K_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,               mul_mv_q3_K_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,               mul_mv_q4_K_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,               mul_mv_q5_K_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,               mul_mv_q6_K_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,            mul_mv_iq2_xxs_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,             mul_mv_iq2_xs_f32,              ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,            mul_mv_iq3_xxs_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,              mul_mv_iq3_s_f32,               ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,              mul_mv_iq2_s_f32,               ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,              mul_mv_iq1_s_f32,               ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,              mul_mv_iq1_m_f32,               ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,             mul_mv_iq4_nl_f32,              ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,             mul_mv_iq4_xs_f32,              ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,             mul_mv_id_f32_f32,              ctx->support_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,             mul_mv_id_f16_f16,              ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,             mul_mv_id_f16_f32,              ctx->support_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,        mul_mv_id_f16_f32_1row,         ctx->support_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,          mul_mv_id_f16_f32_l4,           ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,            mul_mv_id_q4_0_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,            mul_mv_id_q4_1_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,            mul_mv_id_q5_0_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,            mul_mv_id_q5_1_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,            mul_mv_id_q8_0_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,            mul_mv_id_q2_K_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,            mul_mv_id_q3_K_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,            mul_mv_id_q4_K_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,            mul_mv_id_q5_K_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,            mul_mv_id_q6_K_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,         mul_mv_id_iq2_xxs_f32,          ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,          mul_mv_id_iq2_xs_f32,           ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,         mul_mv_id_iq3_xxs_f32,          ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,           mul_mv_id_iq3_s_f32,            ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,           mul_mv_id_iq2_s_f32,            ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,           mul_mv_id_iq1_s_f32,            ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,           mul_mv_id_iq1_m_f32,            ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,          mul_mv_id_iq4_nl_f32,           ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,          mul_mv_id_iq4_xs_f32,           ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,                mul_mm_f32_f32,                 ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,                mul_mm_f16_f32,                 ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,               mul_mm_q4_0_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,               mul_mm_q4_1_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,               mul_mm_q5_0_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,               mul_mm_q5_1_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,               mul_mm_q8_0_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,               mul_mm_q2_K_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,               mul_mm_q3_K_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,               mul_mm_q4_K_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,               mul_mm_q5_K_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,               mul_mm_q6_K_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,            mul_mm_iq2_xxs_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,             mul_mm_iq2_xs_f32,              ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,            mul_mm_iq3_xxs_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,              mul_mm_iq3_s_f32,               ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,              mul_mm_iq2_s_f32,               ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,              mul_mm_iq1_s_f32,               ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,              mul_mm_iq1_m_f32,               ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,             mul_mm_iq4_nl_f32,              ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,             mul_mm_iq4_xs_f32,              ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,             mul_mm_id_f32_f32,              ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,             mul_mm_id_f16_f32,              ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,            mul_mm_id_q4_0_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,            mul_mm_id_q4_1_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,            mul_mm_id_q5_0_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,            mul_mm_id_q5_1_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,            mul_mm_id_q8_0_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,            mul_mm_id_q2_K_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,            mul_mm_id_q3_K_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,            mul_mm_id_q4_K_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,            mul_mm_id_q5_K_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,            mul_mm_id_q6_K_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,         mul_mm_id_iq2_xxs_f32,          ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,          mul_mm_id_iq2_xs_f32,           ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,         mul_mm_id_iq3_xxs_f32,          ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,           mul_mm_id_iq3_s_f32,            ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,           mul_mm_id_iq2_s_f32,            ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,           mul_mm_id_iq1_s_f32,            ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,           mul_mm_id_iq1_m_f32,            ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,          mul_mm_id_iq4_nl_f32,           ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,          mul_mm_id_iq4_xs_f32,           ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                mul_mv_f16_f16,                 support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                mul_mv_f16_f32,                 support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,           mul_mv_f16_f32_1row,            support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,             mul_mv_f16_f32_l4,              support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,               mul_mv_q4_0_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,               mul_mv_q4_1_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,               mul_mv_q5_0_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,               mul_mv_q5_1_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,               mul_mv_q8_0_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,               mul_mv_q2_K_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,               mul_mv_q3_K_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,               mul_mv_q4_K_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,               mul_mv_q5_K_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,               mul_mv_q6_K_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,            mul_mv_iq2_xxs_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,             mul_mv_iq2_xs_f32,              support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,            mul_mv_iq3_xxs_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,              mul_mv_iq3_s_f32,               support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,              mul_mv_iq2_s_f32,               support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,              mul_mv_iq1_s_f32,               support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,              mul_mv_iq1_m_f32,               support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,             mul_mv_iq4_nl_f32,              support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,             mul_mv_iq4_xs_f32,              support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,             mul_mv_id_f32_f32,              support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,             mul_mv_id_f16_f16,              support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,             mul_mv_id_f16_f32,              support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,        mul_mv_id_f16_f32_1row,         support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,          mul_mv_id_f16_f32_l4,           support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,            mul_mv_id_q4_0_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,            mul_mv_id_q4_1_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,            mul_mv_id_q5_0_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,            mul_mv_id_q5_1_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,            mul_mv_id_q8_0_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,            mul_mv_id_q2_K_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,            mul_mv_id_q3_K_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,            mul_mv_id_q4_K_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,            mul_mv_id_q5_K_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,            mul_mv_id_q6_K_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,         mul_mv_id_iq2_xxs_f32,          support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,          mul_mv_id_iq2_xs_f32,           support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,         mul_mv_id_iq3_xxs_f32,          support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,           mul_mv_id_iq3_s_f32,            support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,           mul_mv_id_iq2_s_f32,            support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,           mul_mv_id_iq1_s_f32,            support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,           mul_mv_id_iq1_m_f32,            support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,          mul_mv_id_iq4_nl_f32,           support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,          mul_mv_id_iq4_xs_f32,           support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,                mul_mm_f32_f32,                 support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,                mul_mm_f16_f32,                 support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,               mul_mm_q4_0_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,               mul_mm_q4_1_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,               mul_mm_q5_0_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,               mul_mm_q5_1_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,               mul_mm_q8_0_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,               mul_mm_q2_K_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,               mul_mm_q3_K_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,               mul_mm_q4_K_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,               mul_mm_q5_K_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,               mul_mm_q6_K_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,            mul_mm_iq2_xxs_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,             mul_mm_iq2_xs_f32,              support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,            mul_mm_iq3_xxs_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,              mul_mm_iq3_s_f32,               support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,              mul_mm_iq2_s_f32,               support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,              mul_mm_iq1_s_f32,               support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,              mul_mm_iq1_m_f32,               support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,             mul_mm_iq4_nl_f32,              support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,             mul_mm_iq4_xs_f32,              support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,             mul_mm_id_f32_f32,              support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,             mul_mm_id_f16_f32,              support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,            mul_mm_id_q4_0_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,            mul_mm_id_q4_1_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,            mul_mm_id_q5_0_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,            mul_mm_id_q5_1_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,            mul_mm_id_q8_0_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,            mul_mm_id_q2_K_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,            mul_mm_id_q3_K_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,            mul_mm_id_q4_K_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,            mul_mm_id_q5_K_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,            mul_mm_id_q6_K_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,         mul_mm_id_iq2_xxs_f32,          support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,          mul_mm_id_iq2_xs_f32,           support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,         mul_mm_id_iq3_xxs_f32,          support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,           mul_mm_id_iq3_s_f32,            support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,           mul_mm_id_iq2_s_f32,            support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,           mul_mm_id_iq1_s_f32,            support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,           mul_mm_id_iq1_m_f32,            support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,          mul_mm_id_iq4_nl_f32,           support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,          mul_mm_id_iq4_xs_f32,           support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,                 rope_norm_f32,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                 rope_norm_f16,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                 rope_neox_f32,                  true);
@@ -643,14 +692,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,          argsort_f32_i32_desc,           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,                leaky_relu_f32,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,        flash_attn_ext_f16_h64,         ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,        flash_attn_ext_f16_h80,         ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        ctx->support_simdgroup_mm);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    ctx->support_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,        flash_attn_ext_f16_h64,         support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,        flash_attn_ext_f16_h80,         support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        support_simdgroup_mm);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true);
@@ -684,7 +733,6 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
     Block_release(ctx->encode_async);
 
     [ctx->queue release];
-    [ctx->device release];
 
     dispatch_release(ctx->d_queue);
 
@@ -742,13 +790,16 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
     return nil;
 }
 
-static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
+static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
     for (size_t i = 0, n = 3; i < n; ++i) {
         if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
             return false;
         }
     }
 
+    const bool support_simdgroup_mm        = ctx_dev->support_simdgroup_mm;
+    const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
+
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
@@ -786,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
         case GGML_OP_GROUP_NORM:
-            return ctx->support_simdgroup_reduction;
+            return support_simdgroup_reduction;
         case GGML_OP_NORM:
         case GGML_OP_ROPE:
             return true;
@@ -812,13 +863,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
             if (op->src[0]->ne[0] == 256) {
                 return false;
             }
-            return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
+            return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
-            return ctx->support_simdgroup_reduction &&
+            return support_simdgroup_reduction &&
                 (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
         case GGML_OP_CPY:
         case GGML_OP_DUP:
@@ -862,9 +913,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
 }
 
 static void ggml_metal_encode_node(
-     struct ggml_backend_metal_context * ctx,
+                        ggml_backend_t   backend,
                                    int   idx,
           id<MTLComputeCommandEncoder>   encoder) {
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
     struct ggml_cgraph * gf = ctx->gf;
 
     struct ggml_tensor * node = ggml_graph_node(gf, idx);
@@ -894,7 +948,7 @@ static void ggml_metal_encode_node(
             } break;
     }
 
-    if (!ggml_metal_supports_op(ctx, dst)) {
+    if (!ggml_metal_supports_op(ctx_dev, dst)) {
         GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
         GGML_ABORT("unsupported op");
     }
@@ -967,6 +1021,8 @@ static void ggml_metal_encode_node(
     //            dst->name);
     //}
 
+    id<MTLDevice> device = ctx_dev->mtl_device;
+
     switch (dst->op) {
         case GGML_OP_CONCAT:
             {
@@ -1675,7 +1731,7 @@ static void ggml_metal_encode_node(
                 // the numbers below are measured on M2 Ultra for 7B and 13B models
                 // these numbers do not translate to other devices or model sizes
                 // TODO: need to find a better approach
-                        if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+                        if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
                             switch (src0t) {
                                 case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
                                 case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
@@ -1695,7 +1751,7 @@ static void ggml_metal_encode_node(
 
                         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                        if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                        if ([device supportsFamily:MTLGPUFamilyApple7] &&
                                 !ggml_is_transposed(src0) &&
                                 !ggml_is_transposed(src1) &&
                                 src1t == GGML_TYPE_F32 &&
@@ -1990,7 +2046,7 @@ static void ggml_metal_encode_node(
                 // ne21 = n_rows
                 const int dst_rows = ne20*ne21;
                 const int dst_rows_min = n_as;
-                const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
+                const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4;
 
                 // max size of the rowids array in the kernel shared buffer
                 GGML_ASSERT(dst_rows <= dst_rows_max);
@@ -2001,7 +2057,7 @@ static void ggml_metal_encode_node(
                 // TODO: for now, always use mat-vec kernels until we figure out how to improve the
                 //       indirect matrix multiplication
                 // !!!
-                if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                if ([device supportsFamily:MTLGPUFamilyApple7] &&
                         ne00 % 32 == 0 && ne00 >= 64 &&
                         dst_rows > dst_rows_min) {
 
@@ -2840,7 +2896,7 @@ static void ggml_metal_encode_node(
 
                     while (true) {
                         const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
-                        if (smem > ctx->device.maxThreadgroupMemoryLength) {
+                        if (smem > device.maxThreadgroupMemoryLength) {
                             break;
                         }
                         nsgmax *= 2;
@@ -2852,8 +2908,8 @@ static void ggml_metal_encode_node(
 
                     const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
 
-                    //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
-                    GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+                    //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+                    GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
 
                     [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
 
@@ -2878,8 +2934,8 @@ static void ggml_metal_encode_node(
 
                     const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
 
-                    //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
-                    GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+                    //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+                    GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
                     [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
 
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
@@ -2954,8 +3010,11 @@ static void ggml_metal_encode_node(
 }
 
 static enum ggml_status ggml_metal_graph_compute(
-        struct ggml_backend_metal_context * ctx,
-                       struct ggml_cgraph * gf) {
+            ggml_backend_t   backend,
+        struct ggml_cgraph * gf) {
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
     // number of nodes encoded by the main thread (empirically determined)
     const int n_main = 128;
 
@@ -2983,7 +3042,7 @@ static enum ggml_status ggml_metal_graph_compute(
 
             if (!ctx->capture_started) {
                 // create capture scope
-                ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
+                ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];
 
                 MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
                 descriptor.captureObject = ctx->capture_scope;
@@ -3087,31 +3146,6 @@ static enum ggml_status ggml_metal_graph_compute(
 
 // backend interface
 
-// default buffer
-static id<MTLDevice> g_backend_device = nil;
-static int g_backend_device_ref_count = 0;
-
-static id<MTLDevice> ggml_backend_metal_get_device(void) {
-    if (g_backend_device == nil) {
-        g_backend_device = MTLCreateSystemDefaultDevice();
-    }
-
-    g_backend_device_ref_count++;
-
-    return g_backend_device;
-}
-
-static void ggml_backend_metal_free_device(void) {
-    assert(g_backend_device_ref_count > 0);
-
-    g_backend_device_ref_count--;
-
-    if (g_backend_device_ref_count == 0) {
-        [g_backend_device release];
-        g_backend_device = nil;
-    }
-}
-
 static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
     return "Metal";
 
@@ -3124,7 +3158,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
     for (int i = 0; i < ctx->n_buffers; i++) {
         [ctx->buffers[i].metal release];
     }
-    ggml_backend_metal_free_device();
+    ggml_backend_metal_device_rel(buffer->buft->device->context);
 
     if (ctx->owned) {
 #if TARGET_OS_OSX
@@ -3227,7 +3261,7 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
         size_aligned += (size_page - (size_aligned % size_page));
     }
 
-    id<MTLDevice> device = ggml_backend_metal_get_device();
+    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
 
     ctx->all_data = ggml_metal_host_malloc(size_aligned);
     ctx->all_size = size_aligned;
@@ -3241,16 +3275,16 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
 
         if (size_aligned > 0) {
             ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
-                            length:size_aligned
-                            options:MTLResourceStorageModeShared
-                            deallocator:nil];
+                                            length:size_aligned
+                                            options:MTLResourceStorageModeShared
+                                            deallocator:nil];
         }
     }
 
     if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
         GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
         free(ctx);
-        ggml_backend_metal_free_device();
+        ggml_backend_metal_device_rel(buft->device->context);
         return NULL;
     }
 
@@ -3265,9 +3299,9 @@ static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_t
 }
 
 static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
-    id<MTLDevice> device = ggml_backend_metal_get_device();
-    size_t max_size = device.maxBufferLength;
-    ggml_backend_metal_free_device();
+    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
+    const size_t max_size = device.maxBufferLength;
+    ggml_backend_metal_device_rel(buft->device->context);
 
     return max_size;
 
@@ -3290,15 +3324,14 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
             /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
             /* .is_host          = */ ggml_backend_metal_buffer_type_is_host,
         },
-        /* .device  = */ NULL,
+        /* .device  = */ &g_ggml_backend_metal_device,
         /* .context = */ NULL,
     };
 
     return &ggml_backend_buffer_type_metal;
 }
 
-// buffer from ptr
-
+// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr
 ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
     struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
 
@@ -3321,7 +3354,7 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
         size_aligned += (size_page - (size_aligned % size_page));
     }
 
-    id<MTLDevice> device = ggml_backend_metal_get_device();
+    id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
 
     // the buffer fits into the max buffer size allowed by the device
     if (size_aligned <= device.maxBufferLength) {
@@ -3386,8 +3419,12 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
 }
 
 static void ggml_backend_metal_free(ggml_backend_t backend) {
-    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+    ggml_backend_metal_device_rel(ctx_dev);
     ggml_metal_free(ctx);
+
     free(backend);
 }
 
@@ -3398,21 +3435,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggm
 }
 
 static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-    struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
-
-    return ggml_metal_graph_compute(metal_ctx, cgraph);
-}
-
-static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
-
-    return ggml_metal_supports_op(metal_ctx, op);
-}
-
-static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
-
-    UNUSED(backend);
+    return ggml_metal_graph_compute(backend, cgraph);
 }
 
 static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
@@ -3459,7 +3482,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            ggml_metal_encode_node(ctx, idx, encoder);
+            ggml_metal_encode_node(backend, idx, encoder);
 
             if (should_capture) {
                 [encoder popDebugGroup];
@@ -3487,8 +3510,8 @@ static struct ggml_backend_i ggml_backend_metal_i = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_metal_graph_compute,
-    /* .supports_op             = */ ggml_backend_metal_supports_op,
-    /* .supports_buft           = */ ggml_backend_metal_supports_buft,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
     /* .offload_op              = */ NULL,
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
@@ -3499,8 +3522,11 @@ static ggml_guid_t ggml_backend_metal_guid(void) {
     return &guid;
 }
 
+// TODO: remove in the future
 ggml_backend_t ggml_backend_metal_init(void) {
-    struct ggml_backend_metal_context * ctx = ggml_metal_init();
+    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
+
+    struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
     if (ctx == NULL) {
         GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
         return NULL;
@@ -3511,7 +3537,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
     *backend = (struct ggml_backend) {
         /* .guid      = */ ggml_backend_metal_guid(),
         /* .interface = */ ggml_backend_metal_i,
-        /* .device    = */ NULL,
+        /* .device    = */ dev,
         /* .context   = */ ctx,
     };
 
@@ -3536,9 +3562,9 @@ void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_ca
 bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
     GGML_ASSERT(ggml_backend_is_metal(backend));
 
-    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
-    return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+    return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
 }
 
 void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
@@ -3548,11 +3574,246 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
     ctx->capture_next_compute = true;
 }
 
-ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+// backend device
+
+static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
+    return "Metal";
 
-ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
-    return ggml_backend_metal_init();
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
+    // acq/rel just to populate ctx->name in case it hasn't been done yet
+    struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
+    ggml_backend_metal_device_acq(ctx_dev);
+    ggml_backend_metal_device_rel(ctx_dev);
+
+    return ctx_dev->name;
+}
+
+static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    if (@available(macOS 10.12, iOS 16.0, *)) {
+        struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
+        id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+        *total = device.recommendedMaxWorkingSetSize;
+        *free  = *total - device.currentAllocatedSize;
+
+        ggml_backend_metal_device_rel(ctx_dev);
+    } else {
+        *free = 1;
+        *total = 1;
+    }
+}
+
+static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_metal_device_get_name(dev);
+    props->description = ggml_backend_metal_device_get_description(dev);
+    props->type        = ggml_backend_metal_device_get_type(dev);
+    ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = (struct ggml_backend_dev_caps) {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ true,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
+    struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
+    if (ctx == NULL) {
+        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
+        return NULL;
+    }
+
+    ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
+
+    *backend = (struct ggml_backend) {
+        /* .guid      = */ ggml_backend_metal_guid(),
+        /* .interface = */ ggml_backend_metal_i,
+        /* .device    = */ dev,
+        /* .context   = */ ctx,
+    };
+
+    ggml_backend_metal_set_n_cb(backend, 1);
+
+    return backend;
 
     GGML_UNUSED(params);
-    GGML_UNUSED(user_data);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_metal_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
+
+    ctx->all_data = ptr;
+    ctx->all_size = size;
+    ctx->owned = false;
+    ctx->n_buffers = 0;
+
+    const size_t size_page = sysconf(_SC_PAGESIZE);
+
+    // page-align the data ptr
+    {
+        const uintptr_t offs = (uintptr_t) ptr % size_page;
+        ptr  = (void *) ((char *) ptr - offs);
+        size += offs;
+    }
+
+    size_t size_aligned = size;
+    if ((size_aligned % size_page) != 0) {
+        size_aligned += (size_page - (size_aligned % size_page));
+    }
+
+    struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
+    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+    // the buffer fits into the max buffer size allowed by the device
+    if (size_aligned <= device.maxBufferLength) {
+        ctx->buffers[ctx->n_buffers].data  = ptr;
+        ctx->buffers[ctx->n_buffers].size  = size;
+        ctx->buffers[ctx->n_buffers].metal = nil;
+
+        if (size_aligned > 0) {
+            ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+            if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
+                return false;
+            }
+        }
+
+        ggml_backend_metal_log_allocated_size(device, size_aligned);
+
+        ++ctx->n_buffers;
+    } else {
+        // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
+        // one of the views
+        const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
+        const size_t size_step = device.maxBufferLength - size_ovlp;
+        const size_t size_view = device.maxBufferLength;
+
+        for (size_t i = 0; i < size; i += size_step) {
+            const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
+
+            ctx->buffers[ctx->n_buffers].data  = (void *) ((uint8_t *) ptr + i);
+            ctx->buffers[ctx->n_buffers].size  = size_step_aligned;
+            ctx->buffers[ctx->n_buffers].metal = nil;
+
+            if (size_step_aligned > 0) {
+                ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+                if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                    GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
+                    return false;
+                }
+            }
+
+            ggml_backend_metal_log_allocated_size(device, size_step_aligned);
+
+            if (i + size_step < size) {
+                GGML_LOG_INFO("\n");
+            }
+
+            ++ctx->n_buffers;
+        }
+    }
+
+    return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
+}
+
+static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    struct ggml_backend_metal_device_context * ctx_dev = dev->context;
+
+    return ggml_metal_supports_op(ctx_dev, op);
+}
+
+static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
+
+    UNUSED(dev);
+}
+
+static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    return false;
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(op);
+}
+
+static struct ggml_backend_device_i ggml_backend_metal_device_i = {
+    /* .get_name             = */ ggml_backend_metal_device_get_name,
+    /* .get_description      = */ ggml_backend_metal_device_get_description,
+    /* .get_memory           = */ ggml_backend_metal_device_get_memory,
+    /* .get_type             = */ ggml_backend_metal_device_get_type,
+    /* .get_props            = */ ggml_backend_metal_device_get_props,
+    /* .init_backend         = */ ggml_backend_metal_device_init,
+    /* .get_buffer_type      = */ ggml_backend_metal_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_metal_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_metal_device_supports_buft,
+    /* .offload_op           = */ ggml_backend_metal_device_offload_op,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend registry
+
+static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
+    return "Metal";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    return &g_ggml_backend_metal_device;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
+}
+
+static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
+    /* .get_name         = */ ggml_backend_metal_reg_get_name,
+    /* .device_count     = */ ggml_backend_metal_reg_device_count,
+    /* .device_get       = */ ggml_backend_metal_reg_device_get,
+    /* .get_proc_address = */ NULL,
+};
+
+ggml_backend_reg_t ggml_backend_metal_reg(void) {
+    // TODO: make this thread-safe somehow?
+    {
+        g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
+            /* .iface   = */ ggml_backend_metal_reg_i,
+            /* .context = */ NULL,
+        };
+
+        g_ggml_backend_metal_device = (struct ggml_backend_device) {
+            /* .iface   = */ ggml_backend_metal_device_i,
+            /* .reg     = */ &g_ggml_backend_metal_reg,
+            /* .context = */ &g_ggml_ctx_dev_main,
+        };
+    }
+
+    return &g_ggml_backend_metal_reg;
 }