]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
opencl: mark `argsort` unsupported if cols exceed workgroup limit (llama/15375)
authorlhez <redacted>
Tue, 19 Aug 2025 18:25:51 +0000 (02:25 +0800)
committerGeorgi Gerganov <redacted>
Fri, 5 Sep 2025 09:53:59 +0000 (12:53 +0300)
src/ggml-opencl/ggml-opencl.cpp

index 8a2ac7e377d15f3fe7662d0a5a9a4bb7aa8bb5e2..df27501361f7fcb3f93196c7b105a7ffa08be282 100644 (file)
@@ -333,6 +333,7 @@ struct ggml_backend_opencl_context {
 
     cl_int alignment;
     size_t max_alloc_size;
+    size_t max_workgroup_size;
     bool fp16_support;
     bool has_vector_subgroup_broadcast;
     bool disable_fusion;
@@ -2218,6 +2219,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
     clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL);
     GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024);
 
+    clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL);
+    GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size);
+
     // Check SVM.
     cl_device_svm_capabilities svm_caps;
     CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0));
@@ -2533,7 +2537,8 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
 }
 
 static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
-    GGML_UNUSED(dev);
+    ggml_backend_opencl_device_context * dev_ctx     = (ggml_backend_opencl_device_context *)dev->context;
+    ggml_backend_opencl_context *        backend_ctx = dev_ctx->backend_ctx;
 
     switch (op->op) {
         case GGML_OP_NONE:
@@ -2708,8 +2713,17 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         }
         case GGML_OP_IM2COL:
             return true;
-        case GGML_OP_ARGSORT:
-            return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ARGSORT: {
+            cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32;
+            int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+
+            int cols = 1;
+            while (cols < op->ne[0]) {
+                cols *= 2;
+            }
+
+            return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32;
+        }
         case GGML_OP_SUM_ROWS:
             return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
         case GGML_OP_FLASH_ATTN_EXT: