git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : check cuda and metal argsort limits and add test (llama/16323)
author: Sigbjørn Skjæret <redacted>
Mon, 29 Sep 2025 09:09:00 +0000 (11:09 +0200)
committer: Georgi Gerganov <redacted>
Mon, 29 Sep 2025 12:18:12 +0000 (15:18 +0300)
* check cuda argsort limits and add test

* add metal check

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-metal/ggml-metal-device.m

index 5cd1e0d862db4fe96eeb4228555d6ffe9451edad..5a9e54721e46386df950613317363b57baddfdff 100644 (file)
@@ -3639,9 +3639,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
-        case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
             return true;
+        case GGML_OP_ARGSORT:
+            // TODO: Support arbitrary column width
+            return op->src[0]->ne[0] <= 1024;
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_GROUP_NORM:
index cced0369d0226e7e269885b5762273f620bd56fa..523f9d71ba14ea96f89ae67b8df5541213b85822 100644 (file)
@@ -683,9 +683,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                    (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_TIMESTEP_EMBEDDING:
-        case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
             return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ARGSORT:
+            // TODO: Support arbitrary column width
+            return op->src[0]->ne[0] <= 1024;
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: