// Device-capability fields (fragment of the backend context struct; the
// enclosing struct definition begins before this chunk).
cl_int alignment;                    // device alignment requirement -- TODO confirm units/query against full file
size_t max_alloc_size;               // CL_DEVICE_MAX_MEM_ALLOC_SIZE, in bytes
size_t max_workgroup_size;           // CL_DEVICE_MAX_WORK_GROUP_SIZE (device-wide upper bound)
bool fp16_support;                   // assumes set from cl_khr_fp16 probing -- TODO confirm
bool has_vector_subgroup_broadcast;
bool disable_fusion;
clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL);
GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024);
+ clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL);
+ GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size);
+
// Check SVM.
cl_device_svm_capabilities svm_caps;
CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0));
}
// Report whether this backend can execute the given tensor op.
// (Fragment: the switch continues past the end of this chunk.)
static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
    // The device context is now needed (previously GGML_UNUSED) to reach the
    // compiled kernels and query per-kernel work-group limits.
    ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context;
    ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx;

    switch (op->op) {
        case GGML_OP_NONE:
            // NOTE(review): the original chunk elided the body of this case
            // (and possibly further cases before IM2COL) -- confirm against
            // the full file.
            return true;
        case GGML_OP_IM2COL:
            return true;
        case GGML_OP_ARGSORT: {
            // The argsort kernel sorts one row per work-group with the row
            // padded up to a power of two; that padded width must not exceed
            // the kernel's own work-group size limit.
            cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32;
            int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);

            // Round ne[0] up to the next power of two.  Use int64_t (the type
            // of ne[]) so the doubling cannot signed-overflow (UB / infinite
            // loop) for rows larger than INT_MAX.
            int64_t padded_cols = 1;
            while (padded_cols < op->ne[0]) {
                padded_cols *= 2;
            }

            return padded_cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32;
        }
        case GGML_OP_SUM_ROWS:
            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
        case GGML_OP_FLASH_ATTN_EXT: