id<MTLDevice> device;
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
+
+ NSLock * lock;
};
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
- res->obj = library;
- res->device = device;
+ res->obj = library;
+ res->device = device;
res->pipelines = ggml_metal_pipelines_init();
+ res->lock = [NSLock new];
return res;
}
res->obj = library;
res->device = device;
res->pipelines = ggml_metal_pipelines_init();
+ res->lock = [NSLock new];
return res;
}
ggml_metal_pipelines_free(lib->pipelines);
+ [lib->lock release];
+
free(lib);
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
- return ggml_metal_pipelines_get(lib->pipelines, name);
+ [lib->lock lock];
+
+ ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
+
+ [lib->lock unlock];
+
+ return res;
}
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
- // note: the pipelines are cached in the library per device, so they are shared across all metal contexts
- ggml_critical_section_start();
+ [lib->lock lock];
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+ ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
if (res) {
- ggml_critical_section_end();
+ [lib->lock unlock];
return res;
}
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
}
if (!mtl_function) {
- ggml_critical_section_end();
+ [lib->lock unlock];
GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
if (error) {
(int) res->obj.threadExecutionWidth);
if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
- ggml_critical_section_end();
+ [lib->lock unlock];
GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
ggml_metal_pipelines_add(lib->pipelines, name, res);
}
- ggml_critical_section_end();
+ [lib->lock unlock];
return res;
}