CUDA: add dynamic shared mem to softmax, refactor general usage (llama/14497)
author Aman Gupta <redacted>
Wed, 2 Jul 2025 23:45:11 +0000 (07:45 +0800)
committer Georgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
src/ggml-cuda/common.cuh
src/ggml-cuda/cross-entropy-loss.cu
src/ggml-cuda/mmq.cuh
src/ggml-cuda/softmax.cu
tests/test-backend-ops.cpp

diff --git a/src/ggml-cuda/common.cuh b/src/ggml-cuda/common.cuh
index ea203550238254b6337fe22012b9b54796d92396..954f74d408f9fc3623d10aefb973434015392fb7 100644
@@ -175,6 +175,20 @@ static const char * cu_get_error_str(CUresult err) {
 #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
 #endif
 
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
+    do { \
+        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \
+        const int id = ggml_cuda_get_device(); \
+        if (!shared_memory_limit_raised[id]) { \
+            CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
+            shared_memory_limit_raised[id] = true; \
+        } \
+    } while (0)
+#else
+#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+
 #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
 #else
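
For context: CUDA caps dynamic shared memory at 48 KiB per block by default, and a kernel may request more only after an explicit opt-in via cudaFuncSetAttribute. The new CUDA_SET_SHARED_MEMORY_LIMIT macro wraps exactly that call, cached once per device, and compiles to a no-op on AMD HIP and MUSA where the attribute does not apply. A minimal standalone sketch of the underlying mechanism; the kernel and sizes below are illustrative, not from this commit:

    #include <cuda_runtime.h>
    #include <cstdio>

    // Illustrative kernel: copies src to dst through a dynamically sized
    // shared-memory buffer declared with "extern __shared__".
    static __global__ void copy_via_smem(const float * src, float * dst, int n) {
        extern __shared__ float buf[];               // dynamic shared memory
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) {
            buf[threadIdx.x] = src[i];
        }
        __syncthreads();
        if (i < n) {
            dst[i] = buf[threadIdx.x];
        }
    }

    int main() {
        const int n             = 1024;
        const int nbytes_shared = 96*1024;           // 96 KiB > 48 KiB default cap

        // Without this opt-in, the launch below fails on devices that otherwise
        // support large dynamic shared memory (compute capability 7.0+).
        cudaError_t err = cudaFuncSetAttribute(copy_via_smem,
            cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared);
        std::printf("set attribute: %s\n", cudaGetErrorString(err));

        float *src, *dst;
        cudaMalloc(&src, n*sizeof(float));
        cudaMalloc(&dst, n*sizeof(float));
        copy_via_smem<<<1, n, nbytes_shared>>>(src, dst, n);
        std::printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
        cudaDeviceSynchronize();
        cudaFree(src);
        cudaFree(dst);
    }

The attribute call itself fails if the requested size exceeds the device's opt-in maximum, which is why every caller below first checks nbytes_shared <= smpbo before applying the macro.
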
diff --git a/src/ggml-cuda/cross-entropy-loss.cu b/src/ggml-cuda/cross-entropy-loss.cu
index 0ce4afbb222bd22c59b6e2710adda9ae5b89d3f6..0c8b0819724e47c1f5970d5292e2c91c6179f114 100644
@@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
     ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
 
     if (nbytes_shared <= smpbo) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
-        if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
-            shared_memory_limit_raised[id] = true;
-        }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
         cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
     } else {
         cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
@@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
     const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
 
     if (nbytes_shared <= smpbo) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
-        if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
-            shared_memory_limit_raised[id] = true;
-        }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
         cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
     } else {
         cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
diff --git a/src/ggml-cuda/mmq.cuh b/src/ggml-cuda/mmq.cuh
index 80baf459c15f2544d38ceb974cf719caeee3fd0f..9696a32046212479b19011c4e5f089dee13c5850 100644
@@ -3016,14 +3016,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
     const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-    static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
-    if (!shared_memory_limit_raised[id]) {
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
-        shared_memory_limit_raised[id] = true;
-    }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, false>), nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, true>),  nbytes_shared);
 
     const int nty  = (args.nrows_x   + mmq_y - 1) / mmq_y;
     const int ntx  = (args.ncols_dst + mmq_x - 1) / mmq_x;
diff --git a/src/ggml-cuda/softmax.cu b/src/ggml-cuda/softmax.cu
index e0eb921d85d53d23493793711ff7b3e43547dd54..14543e978cf0f9f5a7d59a9741d65add2ebbaa02 100644
@@ -2,6 +2,7 @@
 #include "ggml.h"
 #include "softmax.cuh"
 #include <cstdint>
+#include <utility>
 
 template <typename T>
 static __device__ __forceinline__ float t2f32(T val) {
@@ -181,6 +182,37 @@ static __global__ void soft_max_back_f32(
     }
 }
 
+template<int... Ns, typename T>
+static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+                             const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
+{
+    const int id       = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    auto launch_kernel = [=](auto I) -> bool {
+        constexpr int ncols = decltype(I)::value;
+        constexpr int block = (ncols > 1024 ? 1024 : ncols);
+
+        if (p.ncols == ncols) {
+            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
+            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, mask, dst, p);
+            return true;
+        }
+        return false;
+    };
+
+    // unary fold over launch_kernel
+    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
+        return;
+    }
+
+    // default case
+    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
+    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+}
+
+
 template<typename T>
 static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
@@ -193,46 +225,12 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
 
-    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
-    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32<true,   32,   32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 64:
-                soft_max_f32<true,   64,   64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 128:
-                soft_max_f32<true,  128,  128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 256:
-                soft_max_f32<true,  256,  256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 512:
-                soft_max_f32<true,  512,  512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            default:
-                soft_max_f32<true,    0,    0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-        }
+    const int id       = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+
+    if (nbytes_shared <= smpbo) {
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
         soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
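
The deleted switch and the new launch_soft_max_kernels helper are equivalent: each supported column count gets its own template instantiation, and anything else falls back to the generic <true, 0, 0> kernel. The new version drives the dispatch with a fold expression over a template parameter pack, so adding a size means extending one list instead of writing another case. The dispatch pattern in isolation, as a host-only sketch with illustrative names:

    #include <cstdio>
    #include <utility>

    template <int ncols>
    static void run_fixed() {
        std::printf("specialized path, ncols=%d\n", ncols);
    }

    static void run_generic(int ncols) {
        std::printf("generic path, ncols=%d\n", ncols);
    }

    template <int... Ns>
    static void dispatch(int ncols) {
        auto try_one = [&](auto I) -> bool {
            constexpr int n = decltype(I)::value;
            if (ncols == n) {
                run_fixed<n>();   // n is a compile-time constant here,
                return true;      // like soft_max_f32<true, ncols, block>
            }
            return false;
        };
        // ||-fold short-circuits at the first matching size, mirroring the old switch.
        if ((try_one(std::integral_constant<int, Ns>{}) || ...)) {
            return;
        }
        run_generic(ncols);       // default case: ncols stays a runtime value
    }

    int main() {
        dispatch<32, 64, 128, 256, 512, 1024, 2048, 4096>(256); // hits run_fixed<256>
        dispatch<32, 64, 128, 256, 512, 1024, 2048, 4096>(100); // falls through to generic
    }
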
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 51bbcf92c9054ba755d9c59515029b9d48c618a5..5c3db854cb498bc843260aebe3e9fcd8a1c65f6f 100644
@@ -4932,6 +4932,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
 
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
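
These perf cases can be exercised through the backend-ops tool in perf mode, restricted to the softmax op; the exact flags may vary by tree, so check the tool's --help:

    ./test-backend-ops perf -o SOFT_MAX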