enable_language(CUDA)
+ # TODO: Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit
+ if (GGML_CUDA_CUB_3DOT2)
+ include(FetchContent)
+
+ FetchContent_Declare(
+ CCCL
+ GIT_REPOSITORY https://github.com/nvidia/cccl.git
+ GIT_TAG v3.2.0-rc2
+ GIT_SHALLOW TRUE
+ )
+
+ FetchContent_MakeAvailable(CCCL)
+ endif()
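+
+ # Example configure invocation opting into the fetched CCCL (illustrative only):
+ #   cmake -B build -DGGML_CUDA=ON -DGGML_CUDA_CUB_3DOT2=ON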
+
# Replace any plain 12X CUDA architectures with their "architecture-specific" equivalents 12Xa.
# 12X is forwards-compatible, 12Xa is not.
# Notably the Blackwell FP4 tensor core instructions are not forwards compatible and therefore need 12Xa.
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
else ()
+ if (GGML_CUDA_CUB_3DOT2)
+ target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
+ endif()
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()
endif()
endif()
else()
+ if (GGML_CUDA_CUB_3DOT2)
+ target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
+ endif()
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
endif()
if (NOT MSVC)
list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
+ else()
+ # CCCL 3.2 onwards will require a cpp-standard-compliant preprocessor for MSVC
+ # https://github.com/NVIDIA/cccl/pull/6827
+ list(APPEND CUDA_CXX_FLAGS /Zc:preprocessor)
endif()
list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument
}
#ifdef GGML_CUDA_USE_CUB
-static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
- const float * x,
- int * dst,
- const int ncols,
- const int nrows,
- ggml_sort_order order,
- cudaStream_t stream) {
- ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ((size_t) ncols) * nrows);
- ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ((size_t) ncols) * nrows);
+void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+ const float * x,
+ int * dst,
+ const int ncols,
+ const int nrows,
+ ggml_sort_order order,
+ cudaStream_t stream) {
+    ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ((size_t) ncols) * nrows);
+    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ((size_t) ncols) * nrows);
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
int * temp_indices = temp_indices_alloc.get();
size_t temp_storage_bytes = 0;
if (order == GGML_SORT_ORDER_ASC) {
- DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
- temp_indices, dst, // values (indices)
- ncols * nrows, nrows, // num items, num segments
- d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits
- stream);
+ if (nrows == 1) {
+ DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+ temp_indices, dst, // values (indices)
+ ncols, 0, sizeof(float) * 8, stream);
+ } else {
+ DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+ temp_indices, dst, // values (indices)
+ ncols * nrows, nrows, // num items, num segments
+ d_offsets, d_offsets + 1, stream);
+ }
} else {
- DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
- dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
- sizeof(float) * 8, stream);
+ if (nrows == 1) {
+ DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+ temp_indices, dst, // values (indices)
+ ncols, 0, sizeof(float) * 8, stream);
+ } else {
+ DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
+ dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+ }
}
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
void * d_temp_storage = temp_storage_alloc.get();
if (order == GGML_SORT_ORDER_ASC) {
- DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
- ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
- stream);
+ if (nrows == 1) {
+ DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+ temp_indices, dst, // values (indices)
+ ncols, 0, sizeof(float) * 8, stream);
+ } else {
+ DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
+ ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+ }
} else {
- DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
- temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
- 0, sizeof(float) * 8, stream);
+ if (nrows == 1) {
+ DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+ temp_indices, dst, // values (indices)
+ ncols, 0, sizeof(float) * 8, stream);
+ } else {
+ DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
+ temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
+ stream);
+ }
}
}
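+// NOTE: for nrows == 1 the segmented sort degenerates to a single segment, so plain
+// DeviceRadixSort::SortPairs is used above (presumably to avoid the segmented dispatch overhead).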
#endif // GGML_CUDA_USE_CUB
return n;
}
-static void argsort_f32_i32_cuda_bitonic(const float * x,
- int * dst,
- const int ncols,
- const int nrows,
- ggml_sort_order order,
- cudaStream_t stream) {
+void argsort_f32_i32_cuda_bitonic(const float * x,
+ int * dst,
+ const int ncols,
+ const int nrows,
+ ggml_sort_order order,
+ cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);
#include "common.cuh"
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+#ifdef GGML_CUDA_USE_CUB
+void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+ const float * x,
+ int * dst,
+ const int ncols,
+ const int nrows,
+ ggml_sort_order order,
+ cudaStream_t stream);
+#endif // GGML_CUDA_USE_CUB
+void argsort_f32_i32_cuda_bitonic(const float * x,
+ int * dst,
+ const int ncols,
+ const int nrows,
+ ggml_sort_order order,
+ cudaStream_t stream);
int device_count;
struct cuda_device_info {
- int cc; // compute capability
- int nsm; // number of streaming multiprocessors
- size_t smpb; // max. shared memory per block
- size_t smpbo; // max. shared memory per block (with opt-in)
- bool integrated; // Device is integrated as opposed to discrete
- bool vmm; // virtual memory support
- size_t vmm_granularity; // granularity of virtual memory
+ int cc; // compute capability
+ int nsm; // number of streaming multiprocessors
+ size_t smpb; // max. shared memory per block
+ size_t smpbo; // max. shared memory per block (with opt-in)
+ bool integrated; // Device is integrated as opposed to discrete
+ bool vmm; // virtual memory support
+ size_t vmm_granularity; // granularity of virtual memory
size_t total_vram;
- int warp_size; // Number of threads in a dispatch
+ int warp_size; // Number of threads in a dispatch
+ bool supports_cooperative_launch; // whether cooperative launch is supported
};
cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
#include "ggml.h"
#ifdef GGML_CUDA_USE_CUB
-# include <cub/block/block_scan.cuh>
+# include <cub/cub.cuh>
#endif // GGML_CUDA_USE_CUB
template<typename T, int BLOCK_SIZE>
}
}
+#ifdef GGML_CUDA_USE_CUB
+template <typename T>
+static void cumsum_cub(ggml_cuda_pool & pool,
+ const T * src,
+ T * dst,
+ int64_t ne,
+ cudaStream_t stream) {
+ size_t tmp_size = 0;
+
+ // Query how much temp storage CUDA UnBound (CUB) needs
+ cub::DeviceScan::InclusiveSum(nullptr, // d_temp_storage (null = just query size)
+ tmp_size, // reference to size (will be set by CUB)
+ src, // input pointer
+ dst, // output pointer
+ ne, // number of elements
+ stream // CUDA stream to use
+ );
+
+ ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+
+ // Perform the inclusive scan
+ cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream);
+}
+#endif // GGML_CUDA_USE_CUB
+
template<typename T>
static void cumsum_cuda(
- const T * src, T * dst,
+ [[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
if (is_contiguous) {
use_cub = true;
+ const int64_t nrows = ne01 * ne02 * ne03;
+ // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
+ // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
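+    // e.g. a single row with ne00 = 8192 takes the CUB path below; 64 rows of 512 elements each
+    //      (ne00 / nrows == 8) fall through to the custom kernel further down.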
+ if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
+        for (int64_t i = 0; i < nrows; i++) {
+ cumsum_cub(ctx.pool(), src + i * ne00, dst + i * ne00, ne00, stream);
+ }
+ return;
+ }
}
#endif // GGML_CUDA_USE_CUB
dim3 grid_dims(ne01, ne02, ne03);
case GGML_TYPE_F32:
{
cumsum_cuda(
- (const float *)src0->data, (float *)dst->data,
+ ctx, (const float *)src0->data, (float *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
#include "ggml-cuda/count-equal.cuh"
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cross-entropy-loss.cuh"
+#include "ggml-cuda/cumsum.cuh"
#include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/diag.cuh"
#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/top-k.cuh"
#include "ggml-cuda/mean.cuh"
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/topk-moe.cuh"
info.devices[id].nsm = prop.multiProcessorCount;
info.devices[id].smpb = prop.sharedMemPerBlock;
info.devices[id].warp_size = prop.warpSize;
+
+#ifndef GGML_USE_MUSA
+ int supports_coop_launch = 0;
+ CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id));
+ info.devices[id].supports_cooperative_launch = !!supports_coop_launch;
+#else
+ info.devices[id].supports_cooperative_launch = false;
+#endif // !(GGML_USE_MUSA)
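+
+    // NOTE: kernels launched via cudaLaunchCooperativeKernel (grid-wide sync) are only dispatched
+    //       when this attribute is set, see e.g. the soft_max_f32_parallelize_cols path.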
#if defined(GGML_USE_HIP)
info.devices[id].smpbo = prop.sharedMemPerBlock;
case GGML_OP_SUM:
ggml_cuda_op_sum(ctx, dst);
break;
+ case GGML_OP_CUMSUM:
+ ggml_cuda_op_cumsum(ctx, dst);
+ break;
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
break;
case GGML_OP_SSM_SCAN:
ggml_cuda_op_ssm_scan(ctx, dst);
break;
+ case GGML_OP_TOP_K:
+ ggml_cuda_op_top_k(ctx, dst);
+ break;
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
break;
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
- case GGML_OP_CUMSUM:
- ggml_cuda_op_cumsum(ctx, dst);
- break;
case GGML_OP_TRI:
ggml_cuda_op_tri(ctx, dst);
break;
return true;
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
+ case GGML_OP_TOP_K:
case GGML_OP_ARGSORT:
#ifndef GGML_CUDA_USE_CUB
return op->src[0]->ne[0] <= 1024;
#include "common.cuh"
#include "ggml.h"
#include "softmax.cuh"
+
+#ifdef GGML_USE_HIP
+#include <hip/hip_cooperative_groups.h>
+#else
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#endif // GGML_USE_HIP
+
#include <cstdint>
#include <utility>
dst[col] = vals[col] * inv_sum;
}
}
+
+
+// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated
+static __device__ float two_stage_warp_reduce_max(float val) {
+ val = warp_reduce_max(val);
+ if (blockDim.x > WARP_SIZE) {
+ assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
+ __shared__ float local_vals[32];
+ const int warp_id = threadIdx.x / WARP_SIZE;
+ const int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ local_vals[warp_id] = val;
+ }
+ __syncthreads();
+ val = -INFINITY;
+ if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
+ val = local_vals[lane_id];
+ }
+ return warp_reduce_max(val);
+ } else {
+ return val;
+ }
+}
+
+static __device__ float two_stage_warp_reduce_sum(float val) {
+ val = warp_reduce_sum(val);
+ if (blockDim.x > WARP_SIZE) {
+ assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
+ __shared__ float local_vals[32];
+ const int warp_id = threadIdx.x / WARP_SIZE;
+ const int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ local_vals[warp_id] = val;
+ }
+ __syncthreads();
+ val = 0.0f;
+ if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
+ val = local_vals[lane_id];
+ }
+ return warp_reduce_sum(val);
+ } else {
+ return val;
+ }
+}
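+
+// Illustrative sketch of the templated helper hinted at by the TODO above: the two reductions differ
+// only in the warp-level op and its identity element, so they could be unified as below. This is an
+// optional refactor sketch, not required by the kernels in this file; `warp_reduce` is assumed to be
+// a functor/lambda wrapping warp_reduce_max or warp_reduce_sum.
+template <typename Op>
+[[maybe_unused]] static __device__ float two_stage_warp_reduce(float val, Op warp_reduce, float identity) {
+    val = warp_reduce(val);
+    if (blockDim.x > WARP_SIZE) {
+        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
+        __shared__ float local_vals[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            local_vals[warp_id] = val;
+        }
+        __syncthreads();
+        val = lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE) ? local_vals[lane_id] : identity;
+        val = warp_reduce(val);
+    }
+    return val;
+}
+// Usage (equivalent to the functions above):
+//   two_stage_warp_reduce(val, [](float v) { return warp_reduce_max(v); }, -INFINITY);
+//   two_stage_warp_reduce(val, [](float v) { return warp_reduce_sum(v); }, 0.0f);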
+
+// TODO: Template to allow keeping ncols in registers if they fit
+static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
+ float * __restrict__ dst,
+ float * __restrict__ tmp_maxs,
+ float * __restrict__ tmp_sums,
+ const soft_max_params p) {
+ namespace cg = cooperative_groups;
+
+ const cg::grid_group g = cg::this_grid();
+
+ const int tid = threadIdx.x;
+ const int col_start = blockIdx.x * blockDim.x + tid;
+ const int n_elem_per_thread = 4;
+
+ float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
+ float local_max = -INFINITY;
+ const int step_size = gridDim.x * blockDim.x;
+
+ // Compute thread-local max
+ for (int col = col_start; col < p.ncols;) {
+#pragma unroll
+ for (int i = 0; i < n_elem_per_thread; i++) {
+ const int idx = col + i * step_size;
+ local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
+ }
+#pragma unroll
+ for (int i = 0; i < n_elem_per_thread; i++) {
+ local_max = fmaxf(local_max, local_vals[i]);
+ }
+ col += step_size * n_elem_per_thread;
+ }
+
+ // Compute CTA-level max
+ local_max = two_stage_warp_reduce_max(local_max);
+
+ // Store CTA-level max to GMEM
+ if (tid == 0) {
+ tmp_maxs[blockIdx.x] = local_max;
+ }
+ g.sync();
+
+    // Compute global max from CTA-level maxes
+ assert(gridDim.x < blockDim.x); // currently we only support this case
+ if (tid < gridDim.x) {
+ local_max = tmp_maxs[tid];
+ } else {
+ local_max = -INFINITY;
+ }
+ local_max = two_stage_warp_reduce_max(local_max);
+
+ // Compute softmax dividends, accumulate divisor
+ float tmp_expf = 0.0f;
+ for (int col = col_start; col < p.ncols;) {
+#pragma unroll
+ for (int i = 0; i < n_elem_per_thread; i++) {
+ const int idx = col + i * step_size;
+ local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
+ }
+#pragma unroll
+ for (int i = 0; i < n_elem_per_thread; i++) {
+ const int idx = col + i * step_size;
+ if (idx < p.ncols) {
+ const float tmp = expf(local_vals[i] - local_max);
+ tmp_expf += tmp;
+ dst[idx] = tmp;
+ }
+ }
+ col += step_size * n_elem_per_thread;
+ }
+
+ // Reduce divisor within CTA
+ tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+
+ // Store CTA-level sum to GMEM
+ if (tid == 0) {
+ tmp_sums[blockIdx.x] = tmp_expf;
+ }
+ g.sync();
+
+ // Compute global sum from CTA-level sums
+ if (tid < gridDim.x) {
+ tmp_expf = tmp_sums[tid];
+ } else {
+ tmp_expf = 0.0f;
+ }
+ tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+
+ // Divide dividend by global sum + store data
+ for (int col = col_start; col < p.ncols;) {
+#pragma unroll
+ for (int i = 0; i < n_elem_per_thread; i++) {
+ const int idx = col + i * step_size;
+ local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY;
+ }
+#pragma unroll
+ for (int i = 0; i < n_elem_per_thread; i++) {
+ const int idx = col + i * step_size;
+ if (idx < p.ncols) {
+ dst[idx] = local_vals[i] / tmp_expf;
+ }
+ }
+ col += step_size * n_elem_per_thread;
+ }
+}
+
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
}
+// We loop over all rows here instead of parallelizing across gridDim.y because cooperative groups can
+// currently only synchronize the complete grid; synchronizing a subset would require launching as a
+// cluster group, which needs CC >= 9.0.
+// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization
+// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group
+__launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
+                                                                                        float * __restrict__ dst,
+                                                                                        float * __restrict__ tmp_maxs,
+                                                                                        float * __restrict__ tmp_sums,
+                                                                                        const soft_max_params p) {
+    for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) {
+        soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_maxs,
+                                                 tmp_sums, p);
+    }
+}
-template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
+template <typename T>
+static void soft_max_f32_cuda(const float * x,
+ const T * mask,
+ const float * sinks,
+ float * dst,
+ const soft_max_params & params,
+ cudaStream_t stream,
+ [[maybe_unused]] ggml_backend_cuda_context & ctx) {
int nth = WARP_SIZE;
const int64_t ncols_x = params.ncols;
if (nbytes_shared <= smpbo) {
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
} else {
- const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
- soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
+ // Parallelize across SMs for top-p/dist-sampling
+        // The heuristic for choosing between parallelizing rows across SMs and parallelizing a single row while
+        // looping over all rows was tuned on a B6000 GPU and can be adapted further for lower-SM-count GPUs,
+        // though keeping the data in registers should be implemented first as that is the optimal solution.
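+        // e.g. a single row of more than 8192 logits with no mask/sinks, scale == 1.0f and max_bias == 0.0f
+        // (top-p / dist sampling over the vocabulary) takes the cooperative path below; many short rows keep
+        // using the per-row kernel in the else branch.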
+ if (ggml_cuda_info().devices[id].supports_cooperative_launch &&
+ ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr &&
+ params.scale == 1.0f && params.max_bias == 0.0f) {
+            // one float per CTA (the kernel is launched with gridDim.x == nsm)
+            ggml_cuda_pool_alloc<float> tmp_maxs_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm);
+            ggml_cuda_pool_alloc<float> tmp_sums_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm);
+
+ void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr,
+ (void *) &tmp_sums_alloc.ptr, (void *) const_cast<soft_max_params *>(¶ms) };
+ CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
+ dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
+ dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
+ } else {
+ const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
+ soft_max_f32<false, 0, 0>
+ <<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
+ }
}
}
params.m1 = m1;
if (use_f16) {
- soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
+ soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
} else {
- soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
+ soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
}
}
--- /dev/null
+#include "argsort.cuh"
+#include "top-k.cuh"
+
+#ifdef GGML_CUDA_USE_CUB
+# include <cub/cub.cuh>
+# if (CCCL_MAJOR_VERSION > 3 || (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 2))
+# include <cuda/iterator>
+# define CUB_TOP_K_AVAILABLE
+using namespace cub;
+# endif // CCCL >= 3.2
+#endif // GGML_CUDA_USE_CUB
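+
+// NOTE: CUB_TOP_K_AVAILABLE is expected to be defined when CCCL >= 3.2 is in use, e.g. via the
+//       GGML_CUDA_CUB_3DOT2 FetchContent path (v3.2.0-rc2) or a future CUDA Toolkit bundling it;
+//       older CCCL versions fall back to the argsort-based implementation below.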
+
+#ifdef CUB_TOP_K_AVAILABLE
+
+static void top_k_cub(ggml_cuda_pool & pool,
+ const float * src,
+ int * dst,
+ const int ncols,
+ const int k,
+ cudaStream_t stream) {
+ auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
+ cuda::execution::output_ordering::unsorted);
+ auto stream_env = cuda::stream_ref{ stream };
+ auto env = cuda::std::execution::env{ stream_env, requirements };
+
+ auto indexes_in = cuda::make_counting_iterator(0);
+
+ size_t temp_storage_bytes = 0;
+ DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k,
+ env);
+
+ ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
+ void * d_temp_storage = temp_storage_alloc.get();
+
+ DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,
+ ncols, k, env);
+}
+
+#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
+
+static int next_power_of_2(int x) {
+ int n = 1;
+ while (n < x) {
+ n *= 2;
+ }
+ return n;
+}
+
+#endif // CUB_TOP_K_AVAILABLE
+
+void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *) src0->data;
+ int * dst_d = (int *) dst->data;
+ cudaStream_t stream = ctx.stream();
+
+    // sanity checks
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_I32);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ const int64_t ncols = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+ const int64_t k = dst->ne[0];
+ ggml_cuda_pool & pool = ctx.pool();
+#ifdef CUB_TOP_K_AVAILABLE
+ // TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented
+ // https://github.com/NVIDIA/cccl/issues/6391
+ // TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k
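+    // e.g. a single vocab-sized row (ncols on the order of 10^5) with a small k issues one
+    // MaxPairs call; with many rows this loop currently serializes one call per row.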
+ for (int i = 0; i < nrows; i++) {
+ top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream);
+ }
+#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
+ // Fall back to argsort + copy
+ const int ncols_pad = next_power_of_2(ncols);
+ const size_t shared_mem = ncols_pad * sizeof(int);
+ const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+ ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
+ int * tmp_dst = temp_dst_alloc.get();
+
+ if (shared_mem > max_shared_mem || ncols > 1024) {
+ argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+ } else {
+ argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+ }
+ CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
+ cudaMemcpyDeviceToDevice, stream));
+#else // !GGML_CUDA_USE_CUB
+ ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
+ int * tmp_dst = temp_dst_alloc.get();
+ argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+ CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
+ cudaMemcpyDeviceToDevice, stream));
+#endif // CUB_TOP_K_AVAILABLE
+}
--- /dev/null
+#include "common.cuh"
+
+void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
+#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
+#define cudaDeviceGetAttribute hipDeviceGetAttribute
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
+#define cudaLaunchCooperativeKernel hipLaunchCooperativeKernel
#define cudaLaunchHostFunc hipLaunchHostFunc
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#define cudaHostRegisterPortable musaHostRegisterPortable
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
#define cudaHostUnregister musaHostUnregister
+#define cudaLaunchCooperativeKernel musaLaunchCooperativeKernel
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost