]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: add softmax broadcast (#14475)
authorAman Gupta <redacted>
Wed, 2 Jul 2025 12:34:24 +0000 (20:34 +0800)
committerGeorgi Gerganov <redacted>
Wed, 2 Jul 2025 12:48:33 +0000 (15:48 +0300)
* CUDA: add softmax broadcast

* Pass by const ref

* Review: Use blockDims for indexing, remove designated initializers

* Add TODO for noncontigous input/output

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/softmax.cu

index 186492be3402787d49eba3e0044c81856071555e..0fa7fb953d815b3d5b8d5787db336c165cae5a06 100644 (file)
@@ -3329,13 +3329,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX:
-            // TODO: support batching
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
+            return true;
         case GGML_OP_SOFT_MAX_BACK: {
             float max_bias = 0.0f;
             memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
index aac6e0999880a529e63a8a796e1fbcbffa60e57f..e0eb921d85d53d23493793711ff7b3e43547dd54 100644 (file)
@@ -13,6 +13,29 @@ __device__ float __forceinline__ t2f32<half>(half val) {
     return __half2float(val);
 }
 
+struct soft_max_params {
+
+    int64_t nheads;
+    uint32_t n_head_log2;
+    int64_t ncols;
+    int64_t nrows_x;
+    int64_t nrows_y;
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    int64_t nb11;
+    int64_t nb12;
+    int64_t nb13;
+
+    int64_t ne12;
+    int64_t ne13;
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+};
+
 // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
 // As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
 #ifdef __clang__
@@ -21,16 +44,24 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 #endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
-        const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+        const float * x, const T * mask, float * dst, const soft_max_params p) {
+    const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
 
     const int tid  = threadIdx.x;
-    const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+
+    const int64_t i03 = blockIdx.z;
+    const int64_t i02 = blockIdx.y;
+    const int64_t i01 = blockIdx.x;
+
+    //TODO: noncontigous inputs/outputs
+    const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
+
+    const int64_t i11 = i01;
+    const int64_t i12 = i02 % p.ne12;
+    const int64_t i13 = i03 % p.ne13;
 
     x    += int64_t(rowx)*ncols;
-    mask += int64_t(rowy)*ncols * (mask != nullptr);
+    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
     dst  += int64_t(rowx)*ncols;
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
@@ -38,7 +69,7 @@ static __global__ void soft_max_f32(
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
 
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
@@ -55,7 +86,7 @@ static __global__ void soft_max_f32(
             break;
         }
 
-        const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
+        const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -151,63 +182,60 @@ static __global__ void soft_max_back_f32(
 }
 
 template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
+    const int64_t ncols_x = params.ncols;
+
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
-    const dim3 block_nums(nrows_x, 1, 1);
+    const dim3 block_nums(params.ne01, params.ne02, params.ne03);
     const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
-    const uint32_t n_head      = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
                 soft_max_f32<true,   32,   32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 64:
                 soft_max_f32<true,   64,   64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 128:
                 soft_max_f32<true,  128,  128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 256:
                 soft_max_f32<true,  256,  256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 512:
                 soft_max_f32<true,  512,  512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 1024:
                 soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 2048:
                 soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 4096:
                 soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             default:
                 soft_max_f32<true,    0,    0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
         }
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
     }
 }
 
@@ -235,10 +263,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
     const int64_t nrows_y = src0->ne[1];
 
+    const int64_t ne00 = src0->ne[0];
+
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
@@ -247,10 +276,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
+
+    const uint32_t n_head      = src0->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+
+    soft_max_params params = {};
+    params.nheads = src0->ne[2];
+    params.n_head_log2 = n_head_log2;
+    params.ncols = ne00;
+    params.nrows_x = nrows_x;
+    params.nrows_y = nrows_y;
+    params.ne00 = src0->ne[0];
+    params.ne01 = src0->ne[1];
+    params.ne02 = src0->ne[2];
+    params.ne03 = src0->ne[3];
+    params.nb11 = nb11;
+    params.nb12 = nb12;
+    params.nb13 = nb13;
+    params.ne12 = ne12;
+    params.ne13 = ne13;
+    params.scale = scale;
+    params.max_bias = max_bias;
+    params.m0 = m0;
+    params.m1 = m1;
+
     if (use_f16) {
-        soft_max_f32_cuda(src0_d, (const half  *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const half  *) src1_d, dst_d, params, stream);
     } else {
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
     }
 }