]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: Replace init_offsets kernel with iterators in cub-based argsort (llama/18930)
authorOliver Simons <redacted>
Tue, 20 Jan 2026 12:11:01 +0000 (13:11 +0100)
committerGeorgi Gerganov <redacted>
Fri, 30 Jan 2026 11:49:29 +0000 (13:49 +0200)
* CUDA: Replace `init_offsets` with iterators in argsort

This is a QOL improvement, saving us the cost of materializing the
iterator

* Remove unnecessary include from top-k.cu

src/ggml-cuda/argsort.cu
src/ggml-cuda/top-k.cu

index 57c8a99a286c7d5a294f7f0e9f83430b33907824..cf7a44f7adc61058384a49ac8f2af59e4c8a59a1 100644 (file)
@@ -14,12 +14,6 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr
     }
 }
 
-static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
-    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx <= nrows) {
-        offsets[idx] = idx * ncols;
-    }
-}
 
 #ifdef GGML_CUDA_USE_CUB
 void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
@@ -31,18 +25,15 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                               cudaStream_t     stream) {
     ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ncols * nrows);
     ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
-    ggml_cuda_pool_alloc<int>   offsets_alloc(pool, nrows + 1);
 
     int *   temp_indices = temp_indices_alloc.get();
     float * temp_keys    = temp_keys_alloc.get();
-    int *   d_offsets    = offsets_alloc.get();
 
     static const int block_size = 256;
     const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
     init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
 
-    const dim3 offset_grid((nrows + block_size - 1) / block_size);
-    init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
+    auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
 
     CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
 
@@ -57,7 +48,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
             DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
                                            temp_indices, dst,                                  // values (indices)
                                            ncols * nrows, nrows,  // num items, num segments
-                                           d_offsets, d_offsets + 1, stream);
+                                           offset_iterator, offset_iterator + 1, stream);
         }
     } else {
         if (nrows == 1) {
@@ -66,7 +57,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                                                  ncols, 0, sizeof(float) * 8, stream);
         } else {
             DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
-                                                     dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+                                                     dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
+                                                     stream);
         }
     }
 
@@ -80,7 +72,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                                        ncols, 0, sizeof(float) * 8, stream);
         } else {
             DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
-                                           ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+                                           ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
         }
     } else {
         if (nrows == 1) {
@@ -89,8 +81,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                                                  ncols, 0, sizeof(float) * 8, stream);
         } else {
             DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
-                                                     temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
-                                                     stream);
+                                                     temp_indices, dst, ncols * nrows, nrows, offset_iterator,
+                                                     offset_iterator + 1, stream);
         }
     }
 }
index 318ac38691e61f590610074b82e41ef7c1a43e3e..785a18389f294339ad3d527fd4bfe190a834d736 100644 (file)
@@ -4,7 +4,6 @@
 #ifdef GGML_CUDA_USE_CUB
 #    include <cub/cub.cuh>
 #    if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
-#        include <cuda/iterator>
 #        define CUB_TOP_K_AVAILABLE
 using namespace cub;
 #    endif  // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2