sampling : add support for backend sampling (llama/17004)
author    Daniel Bevenius <redacted>
Sun, 4 Jan 2026 20:22:16 +0000 (21:22 +0100)
committer Georgi Gerganov <redacted>
Wed, 14 Jan 2026 07:11:59 +0000 (09:11 +0200)
* sampling : add support for backend sampling

This commit adds support for performing sampling operations on the
backend (e.g. GPU) as part of the model computation graph.

The motivation for this feature is to enable some or all of the sampling
to be performed directly on the backend, as part of the computation graph
being executed.

For example, the backend sampler chain might select/sample a token
directly, in which case only the sampled token needs to be transferred
from device memory to host memory.

It is also possible for the backend samplers to filter the logits, or to
compute and filter the probability distribution, in which case only the
filtered logits or probabilities need to be transferred back to system
memory for further processing by CPU samplers.
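
For illustration, the host side could then read back only the filtered
values (a hypothetical sketch; the tensor handle and element count are
assumptions, not the actual llama.cpp API):

```
#include <vector>
#include "ggml-backend.h"

// Hypothetical sketch: only the surviving (filtered) values are copied
// from device memory back to system memory for the CPU samplers.
static std::vector<float> read_back_filtered_probs(const struct ggml_tensor * probs, size_t n_kept) {
    std::vector<float> out(n_kept);
    ggml_backend_tensor_get(probs, out.data(), 0, n_kept * sizeof(float));
    return out;
}
```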

Currently, backend sampling works in a similar manner to pooling: it is a
function called by build_graph, and the sampler operations become part of
the model's computation graph.
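
As a rough illustration, sampler ops could be appended to the graph like
this (a minimal sketch using the public ggml API, not the actual
build_graph code; the choice of samplers and parameters is made up):

```
#include "ggml.h"

// Minimal sketch: sampler operations are appended to the model graph so
// that only a small tensor has to be copied from device memory to host
// memory after the graph is computed.
static struct ggml_tensor * append_backend_sampling(
        struct ggml_context * ctx,
        struct ggml_tensor  * logits,   // F32 [n_vocab, n_outputs]
        float                 temp,
        int                   top_k) {
    // temperature: scales the logits, does not change their order
    struct ggml_tensor * cur = ggml_scale(ctx, logits, 1.0f/temp);

    // top-k: I32 [top_k, n_outputs] indices of the largest logits per row
    struct ggml_tensor * idx = ggml_top_k(ctx, cur, top_k);

    // only this small tensor needs a device-to-host copy
    return idx;
}
```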

* llama-cli : add backend sampler configuration

* server : add backend sampling options/configuration

* webui : add backend sampling options

* ggml : add initial cumsum implementation for CUDA

* sampling : enable all backend sampler tests

This commit enables all existing backend sampler tests in
test-backend-sampler. Previously, some tests were disabled because the
required ggml operations had not been implemented yet.

* graph : do not include llama-model.h

* sampling : always expose sampled_ids

This commit precomputes and caches the full-vocab token id list in
llama_context's constructor, so llama_get_backend_sampled_token_ids_ith
always returns a valid pointer.

The motivation for this is that it enables both common/sampling.cpp and
src/llama-sampling.cpp to simplify their logic.

Not all backend samplers that process logits need to set the sampled
token ids, as they may not change the order of the logits. For example,
the temperature sampler only scales the logits, and the logit bias
sampler only adds a bias to specific token ids; neither changes the order
of the logits. In these cases there is no device-to-host copy of the
sampled token ids, which is exactly the case where having this
precomputed list is useful.
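
A minimal sketch of the precomputed list (the helper name is an
assumption, not the actual llama_context code):

```
#include <cstdint>
#include <numeric>
#include <vector>

typedef int32_t llama_token; // as in llama.h

// Build the identity token-id list once, so callers always get a valid
// pointer even when the backend samplers did not reorder the logits.
static std::vector<llama_token> make_identity_token_ids(int32_t n_vocab) {
    std::vector<llama_token> ids(n_vocab);
    std::iota(ids.begin(), ids.end(), 0); // 0, 1, ..., n_vocab - 1
    return ids;
}
```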

* sampling : ensure at most one output token per seq

This commit adds a check in the batch allocator to ensure that when
backend sampling is enabled, at most one output token is specified per
sequence.
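
For illustration, the check could look roughly like this (a sketch with
assumed names and data layout, not the actual batch allocator code):

```
#include <cstdint>
#include <vector>

// With backend sampling enabled, a batch may mark at most one output
// token per sequence.
static bool one_output_per_seq(const std::vector<int8_t>  & output,   // per-token output flag
                               const std::vector<int32_t> & seq_id,   // per-token sequence id
                               int32_t                      n_seq_max) {
    std::vector<int32_t> n_outputs(n_seq_max, 0);
    for (size_t i = 0; i < output.size(); ++i) {
        if (output[i] && ++n_outputs[seq_id[i]] > 1) {
            return false; // second output token for the same sequence
        }
    }
    return true;
}
```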

* CUDA: Optimize argsort for GPU-based token sampling

Argsort is currently used for top-k. We optimize argsort in two ways (a
condensed sketch of the resulting dispatch follows the list):

1. Use `DeviceRadixSort` for the single-row/sequence case to parallelize
   it across all SMs.
2. Use `DeviceSegmentedSort` for the multi-row/sequence case, as this is
   the correct entrypoint (the function chooses between different
   execution paths, `DeviceSegmentedRadixSort` being one of them, and
   picks the best one according to heuristics):
   https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedSort.html#overview
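
The dispatch boils down to the following (a condensed sketch, assuming
GGML_CUDA_USE_CUB; the actual implementation also handles the temp-storage
query and sorts the keys in place):

```
#include <cub/cub.cuh>

// DeviceRadixSort for a single row/sequence, DeviceSegmentedSort for
// several. As with all CUB device-wide calls, this is invoked twice:
// once with d_temp == nullptr to query temp_bytes, then again with the
// allocated scratch buffer.
static void sort_pairs_desc(void * d_temp, size_t & temp_bytes,
                            const float * keys_in, float * keys_out,
                            const int * vals_in, int * vals_out,
                            int ncols, int nrows, const int * d_offsets,
                            cudaStream_t stream) {
    if (nrows == 1) {
        cub::DeviceRadixSort::SortPairsDescending(
            d_temp, temp_bytes, keys_in, keys_out, vals_in, vals_out,
            ncols, 0, sizeof(float)*8, stream);
    } else {
        cub::DeviceSegmentedSort::SortPairsDescending(
            d_temp, temp_bytes, keys_in, keys_out, vals_in, vals_out,
            ncols*nrows, nrows, d_offsets, d_offsets + 1, stream);
    }
}
```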

Some perf numbers for an RTX PRO 6000:

On the kernel level, tested with
`GGML_CUDA_DISABLE_GRAPHS=1 ./test-backend-ops -o ARGSORT perf`
Before:
```
  ARGSORT(type=f32,ne=[65000,16,1,1],order=0):                  4130 runs -   359.24 us/run
  ARGSORT(type=f32,ne=[200000,1,1,1],order=0):                  8192 runs -   861.34 us/run
  ARGSORT(type=f32,ne=[200000,16,1,1],order=0):                 1343 runs -  1020.01 us/run
```

After:
```
  ARGSORT(type=f32,ne=[65000,16,1,1],order=0):                  4130 runs -   312.41 us/run
  ARGSORT(type=f32,ne=[200000,1,1,1],order=0):                 16384 runs -    63.48 us/run
  ARGSORT(type=f32,ne=[200000,16,1,1],order=0):                 1343 runs -   874.36 us/run
```

ggml/src/ggml-cuda/CMakeLists.txt
ggml/src/ggml-cuda/argsort.cu
ggml/src/ggml-cuda/argsort.cuh
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/cumsum.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/softmax.cu
ggml/src/ggml-cuda/top-k.cu [new file with mode: 0644]
ggml/src/ggml-cuda/top-k.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/vendors/hip.h
ggml/src/ggml-cuda/vendors/musa.h

index ae8f963f6991fc9d24546451b753e1f39a832ac2..dcc004134d07b1ec39ce46e4e17862c2064a360f 100644 (file)
@@ -54,6 +54,20 @@ if (CUDAToolkit_FOUND)
 
     enable_language(CUDA)
 
+    # TODO: Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit
+    if (GGML_CUDA_CUB_3DOT2)
+        include(FetchContent)
+
+        FetchContent_Declare(
+            CCCL
+            GIT_REPOSITORY https://github.com/nvidia/cccl.git
+            GIT_TAG        v3.2.0-rc2
+            GIT_SHALLOW    TRUE
+        )
+
+        FetchContent_MakeAvailable(CCCL)
+    endif()
+
     # Replace any plain 12X CUDA architectures with their "architecture-specific" equivalents 12Xa.
     # 12X is forwards-compatible, 12Xa is not.
     # Notably the Blackwell FP4 tensor core instructions are not forwards compatible and therefore need 12Xa.
@@ -143,6 +157,9 @@ if (CUDAToolkit_FOUND)
             # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
             target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
         else ()
+            if (GGML_CUDA_CUB_3DOT2)
+                target_link_libraries(ggml-cuda PRIVATE  CCCL::CCCL)
+            endif()
             if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
                 target_link_libraries(ggml-cuda PRIVATE  CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
             else()
@@ -150,6 +167,9 @@ if (CUDAToolkit_FOUND)
             endif()
         endif()
     else()
+        if (GGML_CUDA_CUB_3DOT2)
+            target_link_libraries(ggml-cuda PRIVATE  CCCL::CCCL)
+        endif()
         target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
     endif()
 
@@ -218,6 +238,10 @@ if (CUDAToolkit_FOUND)
 
     if (NOT MSVC)
         list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
+    else()
+        # CCCL 3.2 onwards will require a cpp-standard-compliant preprocessor for MSVC
+        # https://github.com/NVIDIA/cccl/pull/6827
+        list(APPEND CUDA_CXX_FLAGS /Zc:preprocessor)
     endif()
 
     list(JOIN   CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED)  # pass host compiler flags as a single argument
index 99669200ffdb52792bdcb3d1cb0f22f083b450d1..57c8a99a286c7d5a294f7f0e9f83430b33907824 100644 (file)
@@ -22,15 +22,15 @@ static __global__ void init_offsets(int * offsets, const int ncols, const int nr
 }
 
 #ifdef GGML_CUDA_USE_CUB
-static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
-                                     const float *    x,
-                                     int *            dst,
-                                     const int        ncols,
-                                     const int        nrows,
-                                     ggml_sort_order  order,
-                                     cudaStream_t     stream) {
-    ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ((size_t) ncols) * nrows);
-    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ((size_t) ncols) * nrows);
+void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+                              const float *    x,
+                              int *            dst,
+                              const int        ncols,
+                              const int        nrows,
+                              ggml_sort_order  order,
+                              cudaStream_t     stream) {
+    ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ncols * nrows);
+    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
     ggml_cuda_pool_alloc<int>   offsets_alloc(pool, nrows + 1);
 
     int *   temp_indices = temp_indices_alloc.get();
@@ -49,28 +49,49 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
     size_t temp_storage_bytes = 0;
 
     if (order == GGML_SORT_ORDER_ASC) {
-        DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
-                                            temp_indices, dst,                                  // values (indices)
-                                            ncols * nrows, nrows,                            // num items, num segments
-                                            d_offsets, d_offsets + 1, 0, sizeof(float) * 8,  // all bits
-                                            stream);
+        if (nrows == 1) {
+            DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
+                                       temp_indices, dst,                                  // values (indices)
+                                       ncols, 0, sizeof(float) * 8, stream);
+        } else {
+            DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
+                                           temp_indices, dst,                                  // values (indices)
+                                           ncols * nrows, nrows,  // num items, num segments
+                                           d_offsets, d_offsets + 1, stream);
+        }
     } else {
-        DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
-                                                      dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
-                                                      sizeof(float) * 8, stream);
+        if (nrows == 1) {
+            DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
+                                                 temp_indices, dst,                                  // values (indices)
+                                                 ncols, 0, sizeof(float) * 8, stream);
+        } else {
+            DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
+                                                     dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+        }
     }
 
     ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
     void *                        d_temp_storage = temp_storage_alloc.get();
 
     if (order == GGML_SORT_ORDER_ASC) {
-        DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
-                                            ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
-                                            stream);
+        if (nrows == 1) {
+            DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
+                                       temp_indices, dst,  // values (indices)
+                                       ncols, 0, sizeof(float) * 8, stream);
+        } else {
+            DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
+                                           ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+        }
     } else {
-        DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
-                                                      temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
-                                                      0, sizeof(float) * 8, stream);
+        if (nrows == 1) {
+            DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
+                                                 temp_indices, dst,                                  // values (indices)
+                                                 ncols, 0, sizeof(float) * 8, stream);
+        } else {
+            DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
+                                                     temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
+                                                     stream);
+        }
     }
 }
 #endif  // GGML_CUDA_USE_CUB
@@ -141,12 +162,12 @@ static int next_power_of_2(int x) {
     return n;
 }
 
-static void argsort_f32_i32_cuda_bitonic(const float *   x,
-                                         int *           dst,
-                                         const int       ncols,
-                                         const int       nrows,
-                                         ggml_sort_order order,
-                                         cudaStream_t    stream) {
+void argsort_f32_i32_cuda_bitonic(const float *   x,
+                                  int *           dst,
+                                  const int       ncols,
+                                  const int       nrows,
+                                  ggml_sort_order order,
+                                  cudaStream_t    stream) {
     // bitonic sort requires ncols to be power of 2
     const int ncols_pad = next_power_of_2(ncols);
 
index 68a001547ffdb1cb7c2ec5b1a69aedd7d3da4767..22b7306f20201d81009327e8ff9f6eb0a029007e 100644 (file)
@@ -1,3 +1,19 @@
 #include "common.cuh"
 
 void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+#ifdef GGML_CUDA_USE_CUB
+void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+                              const float *    x,
+                              int *            dst,
+                              const int        ncols,
+                              const int        nrows,
+                              ggml_sort_order  order,
+                              cudaStream_t     stream);
+#endif  // GGML_CUDA_USE_CUB
+void argsort_f32_i32_cuda_bitonic(const float *   x,
+                                  int *           dst,
+                                  const int       ncols,
+                                  const int       nrows,
+                                  ggml_sort_order order,
+                                  cudaStream_t    stream);
index 78502057a8403341e20a44a6c401bc1edff69077..55f2f46086d286bad81b56cc28d5e38bc30a9424 100644 (file)
@@ -950,15 +950,16 @@ struct ggml_cuda_device_info {
     int device_count;
 
     struct cuda_device_info {
-        int     cc;                 // compute capability
-        int     nsm;                // number of streaming multiprocessors
-        size_t  smpb;               // max. shared memory per block
-        size_t  smpbo;              // max. shared memory per block (with opt-in)
-        bool    integrated;         // Device is integrated as opposed to discrete
-        bool    vmm;                // virtual memory support
-        size_t  vmm_granularity;    // granularity of virtual memory
+        int     cc;                             // compute capability
+        int     nsm;                            // number of streaming multiprocessors
+        size_t  smpb;                           // max. shared memory per block
+        size_t  smpbo;                          // max. shared memory per block (with opt-in)
+        bool    integrated;                     // Device is integrated as opposed to discrete
+        bool    vmm;                            // virtual memory support
+        size_t  vmm_granularity;                // granularity of virtual memory
         size_t  total_vram;
-        int     warp_size;          // Number of threads in a dispatch
+        int     warp_size;                      // Number of threads in a dispatch
+        bool    supports_cooperative_launch;    // whether cooperative launch is supported
     };
 
     cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
index 3bd1394c51aff04cf2dd53e331ca3dc2f4b2d733..def9c32955f2ddd8ccd92a8e84e903d6e99ed1ef 100644 (file)
@@ -5,7 +5,7 @@
 #include "ggml.h"
 
 #ifdef GGML_CUDA_USE_CUB
-#   include <cub/block/block_scan.cuh>
+#   include <cub/cub.cuh>
 #endif // GGML_CUDA_USE_CUB
 
 template<typename T, int BLOCK_SIZE>
@@ -185,9 +185,34 @@ static __global__ void cumsum_kernel(
     }
 }
 
+#ifdef GGML_CUDA_USE_CUB
+template <typename T>
+static void cumsum_cub(ggml_cuda_pool & pool,
+                       const T *        src,
+                       T *              dst,
+                       int64_t          ne,
+                       cudaStream_t     stream) {
+    size_t tmp_size = 0;
+
+    // Query how much temp storage CUDA UnBound (CUB) needs
+    cub::DeviceScan::InclusiveSum(nullptr,   // d_temp_storage (null = just query size)
+                                  tmp_size,  // reference to size (will be set by CUB)
+                                  src,       // input pointer
+                                  dst,       // output pointer
+                                  ne,        // number of elements
+                                  stream     // CUDA stream to use
+    );
+
+    ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+
+    // Perform the inclusive scan
+    cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream);
+}
+#endif // GGML_CUDA_USE_CUB
+
 template<typename T>
 static void cumsum_cuda(
-        const T * src, T * dst,
+        [[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
         const int64_t  nb0,  const int64_t nb1, const int64_t  nb2, const int64_t  nb3,
@@ -201,6 +226,15 @@ static void cumsum_cuda(
 
     if (is_contiguous) {
         use_cub = true;
+        const int64_t nrows = ne01 * ne02 * ne03;
+        // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
+        // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
+        if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
+            for (int i=0; i<nrows; i++) {
+                cumsum_cub(ctx.pool(), src + i * ne00, dst + i * ne00, ne00, stream);
+            }
+            return;
+        }
     }
 #endif // GGML_CUDA_USE_CUB
     dim3 grid_dims(ne01, ne02, ne03);
@@ -239,7 +273,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         case GGML_TYPE_F32:
             {
                 cumsum_cuda(
-                    (const float *)src0->data, (float *)dst->data,
+                    ctx, (const float *)src0->data, (float *)dst->data,
                     src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                     src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                     dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
index 80d983f9eef078374c4aa40bcde8bf6419f08653..1bbca225d2b87e13e9f61fdab616ae8910b41651 100644 (file)
@@ -19,6 +19,7 @@
 #include "ggml-cuda/count-equal.cuh"
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/cross-entropy-loss.cuh"
+#include "ggml-cuda/cumsum.cuh"
 #include "ggml-cuda/diagmask.cuh"
 #include "ggml-cuda/diag.cuh"
 #include "ggml-cuda/fattn.cuh"
@@ -44,6 +45,7 @@
 #include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/top-k.cuh"
 #include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/topk-moe.cuh"
@@ -231,6 +233,14 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].nsm        = prop.multiProcessorCount;
         info.devices[id].smpb       = prop.sharedMemPerBlock;
         info.devices[id].warp_size  = prop.warpSize;
+
+#ifndef GGML_USE_MUSA
+        int supports_coop_launch = 0;
+        CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id));
+        info.devices[id].supports_cooperative_launch = !!supports_coop_launch;
+#else
+        info.devices[id].supports_cooperative_launch = false;
+#endif // !(GGML_USE_MUSA)
 #if defined(GGML_USE_HIP)
         info.devices[id].smpbo = prop.sharedMemPerBlock;
 
@@ -2677,6 +2687,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM:
             ggml_cuda_op_sum(ctx, dst);
             break;
+        case GGML_OP_CUMSUM:
+            ggml_cuda_op_cumsum(ctx, dst);
+            break;
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
@@ -2689,6 +2702,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SSM_SCAN:
             ggml_cuda_op_ssm_scan(ctx, dst);
             break;
+        case GGML_OP_TOP_K:
+            ggml_cuda_op_top_k(ctx, dst);
+            break;
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;
@@ -2698,9 +2714,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
-        case GGML_OP_CUMSUM:
-            ggml_cuda_op_cumsum(ctx, dst);
-            break;
         case GGML_OP_TRI:
             ggml_cuda_op_tri(ctx, dst);
             break;
@@ -4626,6 +4639,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return true;
         case GGML_OP_SUM:
             return ggml_is_contiguous_rows(op->src[0]);
+        case GGML_OP_TOP_K:
         case GGML_OP_ARGSORT:
 #ifndef GGML_CUDA_USE_CUB
             return op->src[0]->ne[0] <= 1024;
index eeacde0bdb126c173231b4bf748ab2da1a0d9cd6..1ae84ebf63036daa8d286e17070175e25b1c8604 100644 (file)
@@ -1,6 +1,14 @@
 #include "common.cuh"
 #include "ggml.h"
 #include "softmax.cuh"
+
+#ifdef GGML_USE_HIP
+#include <hip/hip_cooperative_groups.h>
+#else
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#endif // GGML_USE_HIP
+
 #include <cstdint>
 #include <utility>
 
@@ -160,6 +168,156 @@ static __global__ void soft_max_f32(
         dst[col] = vals[col] * inv_sum;
     }
 }
+
+
+// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated
+static __device__ float two_stage_warp_reduce_max(float val) {
+    val = warp_reduce_max(val);
+    if (blockDim.x > WARP_SIZE) {
+        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
+        __shared__ float local_vals[32];
+        const int        warp_id = threadIdx.x / WARP_SIZE;
+        const int        lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            local_vals[warp_id] = val;
+        }
+        __syncthreads();
+        val = -INFINITY;
+        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
+            val = local_vals[lane_id];
+        }
+        return warp_reduce_max(val);
+    } else {
+        return val;
+    }
+}
+
+static __device__ float two_stage_warp_reduce_sum(float val) {
+    val = warp_reduce_sum(val);
+    if (blockDim.x > WARP_SIZE) {
+        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
+        __shared__ float local_vals[32];
+        const int        warp_id = threadIdx.x / WARP_SIZE;
+        const int        lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            local_vals[warp_id] = val;
+        }
+        __syncthreads();
+        val = 0.0f;
+        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
+            val = local_vals[lane_id];
+        }
+        return warp_reduce_sum(val);
+    } else {
+        return val;
+    }
+}
+
+// TODO: Template to allow keeping ncols in registers if they fit
+static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
+                                                                float * __restrict__ dst,
+                                                                float * __restrict__ tmp_maxs,
+                                                                float * __restrict__ tmp_sums,
+                                                                const soft_max_params p) {
+    namespace cg = cooperative_groups;
+
+    const cg::grid_group g = cg::this_grid();
+
+    const int tid               = threadIdx.x;
+    const int col_start         = blockIdx.x * blockDim.x + tid;
+    const int n_elem_per_thread = 4;
+
+    float     local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
+    float     local_max                     = -INFINITY;
+    const int step_size                     = gridDim.x * blockDim.x;
+
+    // Compute thread-local max
+    for (int col = col_start; col < p.ncols;) {
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
+        }
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            local_max = fmaxf(local_max, local_vals[i]);
+        }
+        col += step_size * n_elem_per_thread;
+    }
+
+    // Compute CTA-level max
+    local_max = two_stage_warp_reduce_max(local_max);
+
+    // Store CTA-level max to GMEM
+    if (tid == 0) {
+        tmp_maxs[blockIdx.x] = local_max;
+    }
+    g.sync();
+
+    // Compute global max from CTA-level maxs
+    assert(gridDim.x < blockDim.x);  // currently we only support this case
+    if (tid < gridDim.x) {
+        local_max = tmp_maxs[tid];
+    } else {
+        local_max = -INFINITY;
+    }
+    local_max = two_stage_warp_reduce_max(local_max);
+
+    // Compute softmax dividends, accumulate divisor
+    float tmp_expf = 0.0f;
+    for (int col = col_start; col < p.ncols;) {
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
+        }
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            if (idx < p.ncols) {
+                const float tmp = expf(local_vals[i] - local_max);
+                tmp_expf += tmp;
+                dst[idx] = tmp;
+            }
+        }
+        col += step_size * n_elem_per_thread;
+    }
+
+    // Reduce divisor within CTA
+    tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+
+    // Store CTA-level sum to GMEM
+    if (tid == 0) {
+        tmp_sums[blockIdx.x] = tmp_expf;
+    }
+    g.sync();
+
+    // Compute global sum from CTA-level sums
+    if (tid < gridDim.x) {
+        tmp_expf = tmp_sums[tid];
+    } else {
+        tmp_expf = 0.0f;
+    }
+    tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+
+    // Divide dividend by global sum + store data
+    for (int col = col_start; col < p.ncols;) {
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY;
+        }
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            if (idx < p.ncols) {
+                dst[idx] = local_vals[i] / tmp_expf;
+            }
+        }
+        col += step_size * n_elem_per_thread;
+    }
+}
+
 #ifdef __clang__
 #pragma clang diagnostic pop
 #endif // __clang__
@@ -216,9 +374,31 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float
     soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
 }
 
+__launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
+                                                     float * __restrict__ dst,
+                                                     float * __restrict__ tmp_maxs,
+                                                     float * __restrict__ tmp_sums,
+                                                     const soft_max_params p)
+// We loop over all rows instead of parallelizing across gridDim.y as cooperative groups
+// currently only support synchronizing the complete grid if not launched as a cluster group
+// (which requires CC > 9.0)
+// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization
+// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group
+{
+    for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) {
+        soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_maxs,
+                                                 tmp_sums, p);
+    }
+}
 
-template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
+template <typename T>
+static void soft_max_f32_cuda(const float *                                x,
+                              const T *                                    mask,
+                              const float *                                sinks,
+                              float *                                      dst,
+                              const soft_max_params &                      params,
+                              cudaStream_t                                 stream,
+                              [[maybe_unused]] ggml_backend_cuda_context & ctx) {
     int nth = WARP_SIZE;
     const int64_t ncols_x = params.ncols;
 
@@ -236,8 +416,25 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const float * sin
     if (nbytes_shared <= smpbo) {
         launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
-        const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
+        // Parallelize across SMs for top-p/dist-sampling
+        // The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and
+        // Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution.
+        if (ggml_cuda_info().devices[id].supports_cooperative_launch &&
+            ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr &&
+            params.scale == 1.0f && params.max_bias == 0.0f) {
+            ggml_cuda_pool_alloc<float> tmp_maxs_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
+            ggml_cuda_pool_alloc<float> tmp_sums_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
+
+            void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr,
+                                     (void *) &tmp_sums_alloc.ptr, (void *) const_cast<soft_max_params *>(&params) };
+            CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
+                                                   dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
+                                                   dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
+        } else {
+            const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
+            soft_max_f32<false, 0, 0>
+                <<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
+        }
     }
 }
 
@@ -315,9 +512,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     params.m1 = m1;
 
     if (use_f16) {
-        soft_max_f32_cuda(src0_d, (const half  *) src1_d, (const float *) src2_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
     } else {
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
     }
 }
 
diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu
new file mode 100644 (file)
index 0000000..318ac38
--- /dev/null
@@ -0,0 +1,96 @@
+#include "argsort.cuh"
+#include "top-k.cuh"
+
+#ifdef GGML_CUDA_USE_CUB
+#    include <cub/cub.cuh>
+#    if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
+#        include <cuda/iterator>
+#        define CUB_TOP_K_AVAILABLE
+using namespace cub;
+#    endif  // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2
+#endif      // GGML_CUDA_USE_CUB
+
+#ifdef CUB_TOP_K_AVAILABLE
+
+static void top_k_cub(ggml_cuda_pool & pool,
+                      const float *    src,
+                      int *            dst,
+                      const int        ncols,
+                      const int        k,
+                      cudaStream_t     stream) {
+    auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
+                                                 cuda::execution::output_ordering::unsorted);
+    auto stream_env   = cuda::stream_ref{ stream };
+    auto env          = cuda::std::execution::env{ stream_env, requirements };
+
+    auto indexes_in = cuda::make_counting_iterator(0);
+
+    size_t temp_storage_bytes = 0;
+    DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k,
+                         env);
+
+    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
+    void *                        d_temp_storage = temp_storage_alloc.get();
+
+    DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,
+                         ncols, k, env);
+}
+
+#elif defined(GGML_CUDA_USE_CUB)  // CUB_TOP_K_AVAILABLE
+
+static int next_power_of_2(int x) {
+    int n = 1;
+    while (n < x) {
+        n *= 2;
+    }
+    return n;
+}
+
+#endif                            // CUB_TOP_K_AVAILABLE
+
+void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0   = dst->src[0];
+    const float *       src0_d = (const float *) src0->data;
+    int *               dst_d  = (int *) dst->data;
+    cudaStream_t        stream = ctx.stream();
+
+    // are these asserts truly necessary?
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t    ncols = src0->ne[0];
+    const int64_t    nrows = ggml_nrows(src0);
+    const int64_t    k     = dst->ne[0];
+    ggml_cuda_pool & pool  = ctx.pool();
+#ifdef CUB_TOP_K_AVAILABLE
+    // TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented
+    // https://github.com/NVIDIA/cccl/issues/6391
+    // TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k
+    for (int i = 0; i < nrows; i++) {
+        top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream);
+    }
+#elif defined(GGML_CUDA_USE_CUB)  // CUB_TOP_K_AVAILABLE
+    // Fall back to argsort + copy
+    const int    ncols_pad      = next_power_of_2(ncols);
+    const size_t shared_mem     = ncols_pad * sizeof(int);
+    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+    ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
+    int *                     tmp_dst = temp_dst_alloc.get();
+
+    if (shared_mem > max_shared_mem || ncols > 1024) {
+        argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+    } else {
+        argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+    }
+    CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
+                                 cudaMemcpyDeviceToDevice, stream));
+#else                             // GGML_CUDA_USE_CUB
+    ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
+    int *                     tmp_dst = temp_dst_alloc.get();
+    argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+    CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
+                                 cudaMemcpyDeviceToDevice, stream));
+#endif
+}
diff --git a/ggml/src/ggml-cuda/top-k.cuh b/ggml/src/ggml-cuda/top-k.cuh
new file mode 100644 (file)
index 0000000..f4d8f61
--- /dev/null
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 951a88d5678e5a562f862a06e3b9140e706884a7..016b04e5a0c6e2de49209f351577fd0125a62920 100644 (file)
 #define cublasSgemm hipblasSgemm
 #define cublasStatus_t hipblasStatus_t
 #define cublasOperation_t hipblasOperation_t
+#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch
 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
 #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
 #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
+#define cudaDeviceGetAttribute hipDeviceGetAttribute
 #define cudaDeviceProp hipDeviceProp_t
 #define cudaDeviceSynchronize hipDeviceSynchronize
 #define cudaError_t hipError_t
@@ -70,6 +72,7 @@
 #define cudaHostRegisterPortable hipHostRegisterPortable
 #define cudaHostRegisterReadOnly hipHostRegisterReadOnly
 #define cudaHostUnregister hipHostUnregister
+#define cudaLaunchCooperativeKernel hipLaunchCooperativeKernel
 #define cudaLaunchHostFunc hipLaunchHostFunc
 #define cudaMalloc hipMalloc
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
index 221e67f96a72f81115b3bf395da1e7e09eb5fe30..1abb8acfd4b9f7510e98121f700a300a7e8693a6 100644 (file)
@@ -61,6 +61,7 @@
 #define cudaHostRegisterPortable musaHostRegisterPortable
 #define cudaHostRegisterReadOnly musaHostRegisterReadOnly
 #define cudaHostUnregister musaHostUnregister
+#define cudaLaunchCooperativeKernel musaLaunchCooperativeKernel
 #define cudaLaunchHostFunc musaLaunchHostFunc
 #define cudaMalloc musaMalloc
 #define cudaMallocHost musaMallocHost