git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml-cuda: gdn use shared mem for HIP (llama/20366)
author: uvos <redacted>
Wed, 11 Mar 2026 05:06:19 +0000 (06:06 +0100)
committer: Georgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
Suggested-by: Aman Gupta <redacted>
ggml/src/ggml-cuda/gated_delta_net.cu

index d8e81114559f0400ca4729dcdeb68e1cee5bd091..c249bbc86d5a8f83ca11dbba98740de971a327a9 100644 (file)
@@ -2,28 +2,29 @@
 #include "ggml-cuda/common.cuh"
 
 template <int S_v, bool KDA>
-__global__ void gated_delta_net_cuda(const float * q,
-                                     const float * k,
-                                     const float * v,
-                                     const float * g,
-                                     const float * beta,
-                                     const float * curr_state,
-                                     float *       dst,
-                                     int64_t       H,
-                                     int64_t       n_tokens,
-                                     int64_t       n_seqs,
-                                     int64_t       sq1,
-                                     int64_t       sq2,
-                                     int64_t       sq3,
-                                     int64_t       sv1,
-                                     int64_t       sv2,
-                                     int64_t       sv3,
-                                     int64_t       sb1,
-                                     int64_t       sb2,
-                                     int64_t       sb3,
-                                     int64_t       rq1,
-                                     int64_t       rq3,
-                                     float         scale) {
+__global__ void __launch_bounds__(S_v, 1)
+gated_delta_net_cuda(const float * q,
+                     const float * k,
+                     const float * v,
+                     const float * g,
+                     const float * beta,
+                     const float * curr_state,
+                     float *       dst,
+                     const int64_t H,
+                     const int64_t n_tokens,
+                     const int64_t n_seqs,
+                     const int64_t sq1,
+                     const int64_t sq2,
+                     const int64_t sq3,
+                     const int64_t sv1,
+                     const int64_t sv2,
+                     const int64_t sv3,
+                     const int64_t sb1,
+                     const int64_t sb2,
+                     const int64_t sb3,
+                     const int64_t rq1,
+                     const int64_t rq3,
+                     const float   scale) {
     const int64_t h_idx    = blockIdx.x;
     const int64_t sequence = blockIdx.y;
     const int     col      = threadIdx.x;  // each thread owns one column
@@ -40,8 +41,14 @@ __global__ void gated_delta_net_cuda(const float * q,
     curr_state += state_offset;
     attn_data += (sequence * n_tokens * H + h_idx) * S_v;
 
-    // Load state column into registers
+    // GCN and CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229
+    // TODO: check optimal path for RDNA1 and RDNA2 devices.
+#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA)
+    extern __shared__ float s_shared[];
+    float * s = s_shared + col * S_v;
+#else
     float s[S_v];
+#endif
 #pragma unroll
     for (int i = 0; i < S_v; i++) {
         s[i] = curr_state[i * S_v + col];
@@ -114,6 +121,15 @@ __global__ void gated_delta_net_cuda(const float * q,
     }
 }
 
+static size_t calculate_smem(const int sv, int cc)
+{
+    size_t smem = 0;
+    if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
+        smem = sv * sv * sizeof(float);
+    }
+    return smem;
+}
+
 template <bool KDA>
 static void launch_gated_delta_net(
         const float * q_d, const float * k_d, const float * v_d,
@@ -129,25 +145,36 @@ static void launch_gated_delta_net(
     dim3 grid_dims(H, n_seqs, 1);
     dim3 block_dims(S_v, 1, 1);
 
+    int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
     switch (S_v) {
-        case 32:
-            gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
+        case 32: {
+            constexpr int sv = 32;
+            size_t smem = calculate_smem(sv, cc);
+            gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, rq1, rq3, scale);
             break;
-        case 64:
-            gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
+        }
+        case 64: {
+            constexpr int sv = 64;
+            size_t smem = calculate_smem(sv, cc);
+            gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, rq1, rq3, scale);
             break;
-        case 128:
-            gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
+        }
+        case 128: {
+            constexpr int sv = 128;
+            size_t smem = calculate_smem(sv, cc);
+            gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, rq1, rq3, scale);
             break;
+        }
         default:
             GGML_ABORT("fatal error");
             break;