* Add support for CUMSUM and TRI for CUDA.
* Minor optimizations.
* Correct warp_prefix_inclusive_sum in float2 variant to return float2
* Optimize TRI
* Whitespace
* Fix strides.
* Implement double loop
* Whitespace
* Fix HIP compilation bugs
* Optimizations + big case performance tests
* Implement using CUB with fallback to custom kernel
* Remove error message.
* Fixes from code review
* Comment out CPU-unsupported F16/BF16 cases to fix CI
* Fine, you win :P
* Fix last cast, use NO_DEVICE_CODE and GGML_UNUSED_VARS
* Vary warp-size based on physical warp size
* Add GGML_UNUSED_VARS in tri as well
* Use constexpr and call prefix_inclusive with warp_size template param
* Update ggml/src/ggml-cuda/cumsum.cu
Co-authored-by: Johannes Gäßler <redacted>
* Apply suggestions from code review
Co-authored-by: Johannes Gäßler <redacted>
* Change to tid % warp_size
* Fix strides; hardcode mask; add ggml_lane_mask_t
* Missing renames, remove unused get_warp_mask(), explicit calls to ggml_cuda_info()
* Too hasty...
---------
Co-authored-by: Johannes Gäßler <redacted>
return x;
}
+template<typename T, int width = WARP_SIZE>
+static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
+ const int lane_id = threadIdx.x % width;
+#pragma unroll
+ for (int offset = 1; offset < width; offset <<= 1) {
+ const T t = __shfl_up_sync(0xffffffff, x, offset, width);
+ if (lane_id >= offset) {
+ x += t;
+ }
+ }
+ return x;
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
+ const int lane_id = threadIdx.x % width;
+#pragma unroll
+ for (int offset = 1; offset < width; offset <<= 1) {
+ const float t_x = __shfl_up_sync(0xffffffff, a.x, offset, width);
+ const float t_y = __shfl_up_sync(0xffffffff, a.y, offset, width);
+ if (lane_id >= offset) {
+ a.x += t_x;
+ a.y += t_y;
+ }
+ }
+ return a;
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
+#ifdef FP16_AVAILABLE
+ const int lane_id = threadIdx.x % width;
+#pragma unroll
+ for (int offset = 1; offset < width; offset <<= 1) {
+ const half2 t = __shfl_up_sync(0xffffffff, a, offset, width);
+ if (lane_id >= offset) {
+ a = __hadd2(a, t);
+ }
+ }
+ return a;
+
+#else
+ NO_DEVICE_CODE;
+ return a;
+#endif // FP16_AVAILABLE
+}
+
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#ifdef FP16_AVAILABLE
--- /dev/null
+#include <algorithm>
+#include "cumsum.cuh"
+#include "convert.cuh"
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
+
+#ifdef GGML_CUDA_USE_CUB
+# include <cub/device/device_scan.cuh>
+#endif // GGML_CUDA_USE_CUB
+
+template<typename T, int BLOCK_SIZE>
+static __global__ void cumsum_cub_kernel(
+ const T * __restrict__ src,
+ T * __restrict__ dst,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const int64_t s01, const int64_t s02, const int64_t s03,
+ const int64_t s1, const int64_t s2, const int64_t s3) {
+#ifdef GGML_CUDA_USE_CUB
+ using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;
+
+ __shared__ typename BlockScan::TempStorage temp_storage;
+ __shared__ T block_carry; // carry from previous tile
+
+ const int tid = threadIdx.x;
+
+ const int64_t i1 = blockIdx.x;
+ const int64_t i2 = blockIdx.y;
+ const int64_t i3 = blockIdx.z;
+
+ if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
+ return;
+ }
+
+ const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
+ T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
+
+ if (tid == 0) {
+ block_carry = 0;
+ }
+ __syncthreads();
+
+ for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) {
+ int64_t idx = start + tid;
+ T x = (idx < ne00) ? src_row[idx] : T(0);
+
+ T inclusive;
+ T block_total;
+ BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total);
+
+ __syncthreads();
+
+ T final_val = inclusive + block_carry;
+
+ // store result
+ if (idx < ne00) {
+ dst_row[idx] = final_val;
+ }
+
+ __syncthreads();
+
+ if (tid == 0) {
+ block_carry += block_total;
+ }
+
+ __syncthreads();
+ }
+#else
+ NO_DEVICE_CODE;
+#endif // GGML_CUDA_USE_CUB
+}
+
+// Fallback kernel implementation (original)
+template<typename T>
+static __global__ void cumsum_kernel(
+ const T * src, T * dst,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03,
+ const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3) {
+
+ GGML_UNUSED_VARS(s00, s0);
+
+ const int tid = threadIdx.x;
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+ const int lane = tid % warp_size;
+ const int warp = tid / warp_size;
+ const int warps_per_block = blockDim.x / warp_size;
+
+ extern __shared__ float smem[];
+ float * s_vals = smem;
+ float * s_warp_sums = smem + blockDim.x;
+ float * s_carry = smem + blockDim.x + warps_per_block;
+ float * s_chunk_total = s_carry + 1;
+
+ // Initialize carry
+ if (tid == 0) {
+ *s_carry = 0.0f;
+ }
+ __syncthreads();
+
+ const int64_t i3 = blockIdx.z;
+ const int64_t i2 = blockIdx.y;
+ const int64_t i1 = blockIdx.x;
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+ return;
+ }
+
+ const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
+ T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
+
+ for (int64_t start = 0; start < ne00; start += blockDim.x) {
+ int64_t idx = start + tid;
+ float val = (idx < ne00) ? ggml_cuda_cast<float, T>(src_row[idx]) : 0.0f;
+
+ // 1. Warp inclusive scan
+ val = warp_prefix_inclusive_sum<T, warp_size>(val);
+ s_vals[tid] = val;
+
+ // Store warp total
+ if (lane == warp_size - 1) {
+ s_warp_sums[warp] = val;
+ }
+ __syncthreads();
+
+ // 2. Exclusive scan of warp sums (warp 0 only)
+ if (warp == 0) {
+ float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
+ float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
+ if (tid < warps_per_block) {
+ s_warp_sums[tid] = inc - w; // exclusive sum
+ }
+ if (tid == warps_per_block - 1) {
+ *s_chunk_total = inc; // total sum of this chunk
+ }
+ }
+ __syncthreads();
+
+ float carry = *s_carry;
+ float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
+ if (idx < ne00) {
+ dst_row[idx] = ggml_cuda_cast<T, float>(final_val);
+ }
+ __syncthreads();
+
+ // Update carry for next chunk
+ if (tid == 0) {
+ *s_carry += *s_chunk_total;
+ }
+ __syncthreads();
+ }
+}
+
+template<typename T>
+static void cumsum_cuda(
+ const T * src, T * dst,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+ const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+ cudaStream_t stream) {
+
+ const size_t type_size = sizeof(T);
+ bool use_cub = false;
+#ifdef GGML_CUDA_USE_CUB
+ // Check if we can use CUB (data must be contiguous along innermost dimension)
+ const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size);
+
+ if (is_contiguous) {
+ use_cub = true;
+ }
+#endif // GGML_CUDA_USE_CUB
+ dim3 grid_dims(ne01, ne02, ne03);
+ const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()];
+ const int warp_size = info.warp_size;
+ const int num_warps = (ne00 + warp_size - 1) / warp_size;
+ int block_size = num_warps * warp_size;
+ block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
+ dim3 block_dims(block_size, 1, 1);
+ const int warps_per_block = block_size / warp_size;
+ const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
+
+ if (use_cub) {
+ cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
+ src, dst,
+ ne00, ne01, ne02, ne03,
+ nb01 / type_size, nb02 / type_size, nb03 / type_size,
+ nb1 / type_size, nb2 / type_size, nb3 / type_size
+ );
+ } else {
+ cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
+ src, dst,
+ ne00, ne01, ne02, ne03,
+ nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+ nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+ );
+ }
+}
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == dst->type);
+ switch(src0->type) {
+ case GGML_TYPE_F32:
+ {
+ cumsum_cuda(
+ (const float *)src0->data, (float *)dst->data,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+ stream
+ );
+ } break;
+ // We do not support those on CPU for now anyway, so comment them out because they cause errors on some CI platforms
+ /*case GGML_TYPE_F16:
+ {
+ cumsum_cuda(
+ (const half *)src0->data, (half *)dst->data,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+ stream
+ );
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ cumsum_cuda(
+ (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+ stream
+ );
+ } break;*/
+ default:
+ GGML_ABORT("fatal error");
+ }
+}
--- /dev/null
+#include "common.cuh"
+
+#define CUDA_CUMSUM_BLOCK_SIZE 256
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/pad_reflect_1d.cuh"
#include "ggml-cuda/solve_tri.cuh"
+#include "ggml-cuda/tri.cuh"
+#include "ggml-cuda/cumsum.cuh"
#include "ggml.h"
#include <algorithm>
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
+ case GGML_OP_CUMSUM:
+ ggml_cuda_op_cumsum(ctx, dst);
+ break;
+ case GGML_OP_TRI:
+ ggml_cuda_op_tri(ctx, dst);
+ break;
case GGML_OP_RWKV_WKV6:
ggml_cuda_op_rwkv_wkv6(ctx, dst);
break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
+ case GGML_OP_CUMSUM:
+ case GGML_OP_TRI:
return true;
case GGML_OP_SOLVE_TRI:
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
--- /dev/null
+#include "common.cuh"
+#include "convert.cuh"
+#include "tri.cuh"
+#include "ggml.h"
+
+template<typename T, bool prefix_keep, int add_to_split>
+static __global__ void tri_kernel(
+ const T * src, T * dst,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+ const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
+ const int64_t i3 = blockIdx.z;
+ const int64_t i2 = blockIdx.y;
+ const int64_t i1 = blockIdx.x;
+ const int64_t split_point = i1 + add_to_split;
+
+ GGML_UNUSED_VARS(nb00, nb0);
+
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+ return;
+ }
+
+ const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03;
+ T * dst_row = dst + i1*nb1 + i2*nb2 + i3*nb3;
+
+ if constexpr (prefix_keep) {
+ for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
+ dst_row[i0] = src_row[i0];
+ }
+ for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
+ dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
+ }
+ } else {
+ for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
+ dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
+ }
+ for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
+ dst_row[i0] = src_row[i0];
+ }
+ }
+}
+
+template<typename T>
+static void tri_cuda(
+ const T * src, T * dst,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+ const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+ const ggml_tri_type ttype,
+ cudaStream_t stream) {
+
+ dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
+ dim3 grid_dims(ne01, ne02, ne03);
+ const size_t type_size = sizeof(T);
+
+ const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;
+ const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);
+
+ if (prefix_keep) {
+ if (add_to_split == 0) {
+ tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>(
+ src, dst,
+ ne00, ne01, ne02, ne03,
+ nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+ nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+ );
+ } else { // only 0 and 1 supported
+ tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>(
+ src, dst,
+ ne00, ne01, ne02, ne03,
+ nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+ nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+ );
+ }
+ } else {
+ if (add_to_split == 0) {
+ tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>(
+ src, dst,
+ ne00, ne01, ne02, ne03,
+ nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+ nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+ );
+ } else {
+ tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>(
+ src, dst,
+ ne00, ne01, ne02, ne03,
+ nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+ nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+ );
+ }
+ }
+}
+
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ cudaStream_t stream = ctx.stream();
+
+ const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));
+
+ GGML_ASSERT(src0->type == dst->type);
+
+ switch(src0->type) {
+ case GGML_TYPE_F32:
+ {
+ tri_cuda(
+ (const float *)src0->data, (float *)dst->data,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+ ttype, stream
+ );
+ } break;
+ case GGML_TYPE_F16:
+ {
+ tri_cuda(
+ (const half *)src0->data, (half *)dst->data,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+ ttype, stream
+ );
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ tri_cuda(
+ (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+ ttype, stream
+ );
+ } break;
+ default:
+ GGML_ABORT("fatal error");
+ }
+}
--- /dev/null
+#include "common.cuh"
+
+#define CUDA_TRI_BLOCK_SIZE 256
+
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 10, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 127, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 5, 4, 3 }));
+ test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 255, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 256, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 511, 5, 4, 3 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
+ test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
+ test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
+
+ test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 }));
+ test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 16, 5, 4 }));
+ test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20000, 10, 4, 1 }));
+
for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {