]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
enhance argsort for UT (llama/17573)
authorNeo Zhang Jianyu <redacted>
Tue, 2 Dec 2025 00:56:46 +0000 (08:56 +0800)
committerGeorgi Gerganov <redacted>
Thu, 11 Dec 2025 13:32:51 +0000 (15:32 +0200)
Co-authored-by: Neo Zhang <redacted>
src/ggml-sycl/ggml-sycl.cpp

index e82b51206e2a16ccd16faba82ad2c28c3ec8ee7e..a264ade0b771b31037407acb5fe73472edfa90d8 100644 (file)
@@ -1787,6 +1787,7 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
     const sycl::range<3> block_dims(1, 1, nth);
     const sycl::range<3> block_nums(1, nrows, 1);
     const size_t shared_mem = ncols_pad * sizeof(int);
+    GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo);
 
     if (order == GGML_SORT_ORDER_ASC) {
         stream->submit([&](sycl::handler &cgh) {
@@ -4348,6 +4349,9 @@ static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_
 }
 
 static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+    ggml_backend_sycl_device_context *sycl_ctx =
+        (ggml_backend_sycl_device_context *)dev->context;
+    int device = sycl_ctx->device;
     switch (op->op) {
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
@@ -4601,8 +4605,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
-        case GGML_OP_ARGSORT:
             return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_ARGSORT:
+            return op->src[0]->ne[0] * sizeof(int) <=
+                   ggml_sycl_info().devices[device].smpbo;
         case GGML_OP_POOL_2D:
         case GGML_OP_ACC:
             return true;