CUDA: use CUB for arbitrary size argsort (llama/16754)
author    Aman Gupta <redacted>
Fri, 24 Oct 2025 12:46:19 +0000 (20:46 +0800)
committer Georgi Gerganov <redacted>
Sat, 1 Nov 2025 07:41:35 +0000 (09:41 +0200)
src/ggml-cuda/argsort.cu
src/ggml-cuda/ggml-cuda.cu

index 607ded8558b45b9b1b40ea9bdd15299c4d883875..6e7b90d42783f3a0d69c2eb4a9b2a7d4ec3236d0 100644 (file)
@@ -1,5 +1,81 @@
 #include "argsort.cuh"
 
+#ifdef GGML_CUDA_USE_CUB
+#    include <cub/cub.cuh>
+using namespace cub;
+#endif  // GGML_CUDA_USE_CUB
+
+static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
+    const int col = blockIdx.x * blockDim.x + threadIdx.x;
+    const int row = blockIdx.y;
+
+    if (col < ncols && row < nrows) {
+        indices[row * ncols + col] = col;
+    }
+}
+
+static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx <= nrows) {
+        offsets[idx] = idx * ncols;
+    }
+}
+
+#ifdef GGML_CUDA_USE_CUB
+static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+                                     const float *    x,
+                                     int *            dst,
+                                     const int        ncols,
+                                     const int        nrows,
+                                     ggml_sort_order  order,
+                                     cudaStream_t     stream) {
+    ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ncols * nrows);
+    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
+    ggml_cuda_pool_alloc<int>   offsets_alloc(pool, nrows + 1);
+
+    int *   temp_indices = temp_indices_alloc.get();
+    float * temp_keys    = temp_keys_alloc.get();
+    int *   d_offsets    = offsets_alloc.get();
+
+    static const int block_size = 256;
+    const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
+    init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
+
+    const dim3 offset_grid((nrows + block_size - 1) / block_size);
+    init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
+
+    cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
+
+    size_t temp_storage_bytes = 0;
+
+    if (order == GGML_SORT_ORDER_ASC) {
+        DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
+                                            temp_indices, dst,                                  // values (indices)
+                                            ncols * nrows, nrows,                            // num items, num segments
+                                            d_offsets, d_offsets + 1, 0, sizeof(float) * 8,  // all bits
+                                            stream);
+    } else {
+        DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
+                                                      dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
+                                                      sizeof(float) * 8, stream);
+    }
+
+    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
+    void *                        d_temp_storage = temp_storage_alloc.get();
+
+    if (order == GGML_SORT_ORDER_ASC) {
+        DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
+                                            ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
+                                            stream);
+    } else {
+        DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
+                                                      temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
+                                                      0, sizeof(float) * 8, stream);
+    }
+}
+#endif  // GGML_CUDA_USE_CUB
+
+// Bitonic sort implementation
 template<typename T>
 static inline __device__ void ggml_cuda_swap(T & a, T & b) {
     T tmp = a;
@@ -65,7 +141,12 @@ static int next_power_of_2(int x) {
     return n;
 }
 
-static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+static void argsort_f32_i32_cuda_bitonic(const float *   x,
+                                         int *           dst,
+                                         const int       ncols,
+                                         const int       nrows,
+                                         ggml_sort_order order,
+                                         cudaStream_t    stream) {
     // bitonic sort requires ncols to be power of 2
     const int ncols_pad = next_power_of_2(ncols);
 
@@ -77,9 +158,11 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
     GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
 
     if (order == GGML_SORT_ORDER_ASC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+        k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
+            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else if (order == GGML_SORT_ORDER_DESC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+        k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
+            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else {
         GGML_ABORT("fatal error");
     }
@@ -100,5 +183,18 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
-    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
+#ifdef GGML_CUDA_USE_CUB
+    const int    ncols_pad      = next_power_of_2(ncols);
+    const size_t shared_mem     = ncols_pad * sizeof(int);
+    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+    if (shared_mem > max_shared_mem || ncols > 1024) {
+        ggml_cuda_pool & pool = ctx.pool();
+        argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
+    } else {
+        argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+    }
+#else
+    argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+#endif
 }
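
A minimal, standalone sketch of the two-phase cub::DeviceSegmentedRadixSort::SortPairs call pattern that the new argsort_f32_i32_cuda_cub path relies on: the first call with a null temporary-storage pointer only reports the required scratch size, and the second call performs the per-segment (per-row) sort. This example is illustrative and not part of the patch; it uses plain cudaMalloc and the default stream instead of ggml's pool allocator and the backend stream, and it omits error checking.

// Hypothetical standalone illustration (not part of the commit): two rows of
// float keys are argsorted per row with cub::DeviceSegmentedRadixSort.
#include <cub/cub.cuh>
#include <cstdio>
#include <vector>

int main() {
    const int ncols = 5, nrows = 2;
    const int n     = ncols * nrows;

    // Two rows (segments) of float keys; per-row indices 0..ncols-1 are the values.
    std::vector<float> h_keys    = { 3.f, 1.f, 4.f, 1.f, 5.f,
                                     9.f, 2.f, 6.f, 5.f, 3.f };
    std::vector<int>   h_vals    = { 0, 1, 2, 3, 4, 0, 1, 2, 3, 4 };
    std::vector<int>   h_offsets = { 0, ncols, 2 * ncols };  // nrows + 1 segment offsets

    float * d_keys_in;  float * d_keys_out;
    int   * d_vals_in;  int   * d_vals_out;  int * d_offsets;
    cudaMalloc(&d_keys_in,  n * sizeof(float));
    cudaMalloc(&d_keys_out, n * sizeof(float));
    cudaMalloc(&d_vals_in,  n * sizeof(int));
    cudaMalloc(&d_vals_out, n * sizeof(int));
    cudaMalloc(&d_offsets,  (nrows + 1) * sizeof(int));
    cudaMemcpy(d_keys_in, h_keys.data(),    n * sizeof(float),         cudaMemcpyHostToDevice);
    cudaMemcpy(d_vals_in, h_vals.data(),    n * sizeof(int),           cudaMemcpyHostToDevice);
    cudaMemcpy(d_offsets, h_offsets.data(), (nrows + 1) * sizeof(int), cudaMemcpyHostToDevice);

    // Phase 1: null temp-storage pointer -> only query the scratch size.
    void * d_temp_storage     = nullptr;
    size_t temp_storage_bytes = 0;
    cub::DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes,
                                             d_keys_in, d_keys_out, d_vals_in, d_vals_out,
                                             n, nrows, d_offsets, d_offsets + 1);

    // Phase 2: allocate the scratch buffer and run the segmented sort for real.
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
    cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
                                             d_keys_in, d_keys_out, d_vals_in, d_vals_out,
                                             n, nrows, d_offsets, d_offsets + 1);

    // Print the per-row argsort result (indices of each row in ascending key order).
    std::vector<int> h_out(n);
    cudaMemcpy(h_out.data(), d_vals_out, n * sizeof(int), cudaMemcpyDeviceToHost);
    for (int r = 0; r < nrows; ++r) {
        for (int c = 0; c < ncols; ++c) {
            printf("%d ", h_out[r * ncols + c]);
        }
        printf("\n");
    }
    return 0;
}

Querying the size first is what allows the patch to take the scratch buffer from the ggml CUDA pool (ggml_cuda_pool_alloc) before issuing the real sort on the backend stream.
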
index f5a6a751acfd5293e75389d59ac9224a32015e36..bc396b521af0751aa27eacba3d01d888588b79aa 100644 (file)
@@ -3642,8 +3642,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUM:
             return ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_ARGSORT:
-            // TODO: Support arbitrary column width
+#ifndef GGML_CUDA_USE_CUB
             return op->src[0]->ne[0] <= 1024;
+#else
+            return true;
+#endif
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_GROUP_NORM: