]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : fix thread-safety (#14300)
authorGeorgi Gerganov <redacted>
Sat, 21 Jun 2025 05:04:18 +0000 (08:04 +0300)
committerGitHub <redacted>
Sat, 21 Jun 2025 05:04:18 +0000 (08:04 +0300)
ggml-ci

ggml/src/ggml-metal/ggml-metal.m

index 4e7f373cb435a6cf6cc07b137c89c67691bfd502..19f4d59e59747452920f1e9d72908681afbde0a2 100644 (file)
@@ -48,22 +48,28 @@ static struct ggml_backend_metal_device_context {
     int            mtl_device_ref_count;
     id<MTLLibrary> mtl_library;
 
+    NSLock * mtl_lock;
+
     bool has_simdgroup_reduction;
     bool has_simdgroup_mm;
     bool has_residency_sets;
     bool has_bfloat;
     bool use_bfloat;
 
+    size_t max_size;
+
     char name[128];
 } g_ggml_ctx_dev_main = {
     /*.mtl_device              =*/ nil,
     /*.mtl_device_ref_count    =*/ 0,
     /*.mtl_library             =*/ nil,
+    /*.mtl_lock                =*/ nil,
     /*.has_simdgroup_reduction =*/ false,
     /*.has_simdgroup_mm        =*/ false,
     /*.has_residency_sets      =*/ false,
     /*.has_bfloat              =*/ false,
     /*.use_bfloat              =*/ false,
+    /*.max_size                =*/ 0,
     /*.name                    =*/ "",
 };
 
@@ -71,6 +77,10 @@ static struct ggml_backend_metal_device_context {
 static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
     assert(ctx != NULL);
 
+    if (ctx->mtl_lock == nil) {
+        ctx->mtl_lock = [[NSLock alloc] init];
+    }
+
     if (ctx->mtl_device == nil) {
         ctx->mtl_device = MTLCreateSystemDefaultDevice();
     }
@@ -94,6 +104,8 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
         ctx->use_bfloat = false;
 #endif
 
+        ctx->max_size = ctx->mtl_device.maxBufferLength;
+
         strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
     }
 
@@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     ctx->mtl_device_ref_count--;
 
     if (ctx->mtl_device_ref_count == 0) {
+        if (ctx->mtl_lock) {
+            [ctx->mtl_lock release];
+            ctx->mtl_lock = nil;
+        }
+
         if (ctx->mtl_library) {
             [ctx->mtl_library release];
             ctx->mtl_library = nil;
@@ -977,7 +994,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
     struct ggml_backend_metal_device_context * ctx_dev = dev->context;
 
-    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+    id<MTLDevice> device = ctx_dev->mtl_device;
 
     GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
 
@@ -991,9 +1008,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
     // load library
-    if (ctx_dev->mtl_library == nil) {
-        ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
+    {
+        [ctx_dev->mtl_lock lock];
+
+        if (ctx_dev->mtl_library == nil) {
+            ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
+        }
+
+        [ctx_dev->mtl_lock unlock];
     }
+
     id<MTLLibrary> metal_library = ctx_dev->mtl_library;
     if (metal_library == nil) {
         GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
@@ -5284,7 +5308,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
     }
 
     ggml_backend_metal_buffer_rset_free(ctx);
-    ggml_backend_metal_device_rel(buffer->buft->device->context);
 
     if (ctx->owned) {
 #if TARGET_OS_OSX
@@ -5393,7 +5416,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
     }
 
     struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
-    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+    GGML_ASSERT(ctx_dev->mtl_device != nil);
+
+    id<MTLDevice> device = ctx_dev->mtl_device;
 
     ctx->all_data = ggml_metal_host_malloc(size_aligned);
     ctx->all_size = size_aligned;
@@ -5416,14 +5442,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
     if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
         GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
         free(ctx);
-        ggml_backend_metal_device_rel(ctx_dev);
         return NULL;
     }
 
     if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
         GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
         free(ctx);
-        ggml_backend_metal_device_rel(ctx_dev);
         return NULL;
     }
 
@@ -5434,17 +5458,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
 
 static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     return 32;
+
     GGML_UNUSED(buft);
 }
 
 static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
-    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
-    const size_t max_size = device.maxBufferLength;
-    ggml_backend_metal_device_rel(buft->device->context);
+    const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
 
     return max_size;
-
-    GGML_UNUSED(buft);
 }
 
 static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
@@ -5517,7 +5538,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
     }
 
     struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
-    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+    GGML_ASSERT(ctx_dev->mtl_device != nil);
+
+    id<MTLDevice> device = ctx_dev->mtl_device;
 
     // the buffer fits into the max buffer size allowed by the device
     if (size_aligned <= device.maxBufferLength) {
@@ -5573,7 +5597,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
     if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
         GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
         free(ctx);
-        ggml_backend_metal_device_rel(ctx_dev);
         return NULL;
     }
 
@@ -5589,10 +5612,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
 }
 
 static void ggml_backend_metal_free(ggml_backend_t backend) {
-    struct ggml_backend_metal_context        * ctx     = backend->context;
-    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+    struct ggml_backend_metal_context * ctx = backend->context;
 
-    ggml_backend_metal_device_rel(ctx_dev);
     ggml_metal_free(ctx);
 
     free(backend);
@@ -5732,6 +5753,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
 
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
+    GGML_ASSERT(ctx_dev->mtl_device != nil);
+
     return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
 }
 
@@ -5751,10 +5774,7 @@ static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
 }
 
 static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
-    // acq/rel just to populate ctx->name in case it hasn't been done yet
     struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
-    ggml_backend_metal_device_acq(ctx_dev);
-    ggml_backend_metal_device_rel(ctx_dev);
 
     return ctx_dev->name;
 }
@@ -5762,12 +5782,10 @@ static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t
 static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
     if (@available(macOS 10.12, iOS 16.0, *)) {
         struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
-        id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+        id<MTLDevice> device = ctx_dev->mtl_device;
 
         *total = device.recommendedMaxWorkingSetSize;
         *free  = *total - device.currentAllocatedSize;
-
-        ggml_backend_metal_device_rel(ctx_dev);
     } else {
         *free = 1;
         *total = 1;
@@ -5845,7 +5863,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
     }
 
     struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
-    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+    GGML_ASSERT(ctx_dev->mtl_device != nil);
+
+    id<MTLDevice> device = ctx_dev->mtl_device;
 
     // the buffer fits into the max buffer size allowed by the device
     if (size_aligned <= device.maxBufferLength) {
@@ -5901,7 +5922,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
     if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
         GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
         free(ctx);
-        ggml_backend_metal_device_rel(ctx_dev);
         return NULL;
     }
 
@@ -5915,8 +5935,9 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
 }
 
 static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
-            buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
+    return
+        buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
+        buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
 
     GGML_UNUSED(dev);
 }
@@ -6001,8 +6022,19 @@ static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
     /* .get_proc_address = */ ggml_backend_metal_get_proc_address,
 };
 
+// called upon program exit
+static void ggml_metal_cleanup(void) {
+    ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
+}
+
+// TODO: make thread-safe
 ggml_backend_reg_t ggml_backend_metal_reg(void) {
-    // TODO: make this thread-safe somehow?
+    ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
+
+    // register cleanup callback
+    // TODO: not ideal, but not sure if there is a better way to do this in Objective-C
+    atexit(ggml_metal_cleanup);
+
     {
         g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
             /* .api_version = */ GGML_BACKEND_API_VERSION,