]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: GDN hide memory latency (llama/20537)
authorAman Gupta <redacted>
Mon, 16 Mar 2026 03:41:45 +0000 (11:41 +0800)
committerGeorgi Gerganov <redacted>
Sat, 28 Mar 2026 11:39:09 +0000 (13:39 +0200)
src/ggml-cuda/gated_delta_net.cu

index 1ce6d5f31b5780774e567b365f89b158997c0a62..6b44bec731746fdd393147079fdbfe5fe8107d51 100644 (file)
@@ -1,7 +1,8 @@
 #include "gated_delta_net.cuh"
 
 template <int S_v, bool KDA>
-__global__ void gated_delta_net_cuda(const float * q,
+__global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
+gated_delta_net_cuda(const float * q,
                                      const float * k,
                                      const float * v,
                                      const float * g,
@@ -38,7 +39,7 @@ __global__ void gated_delta_net_cuda(const float * q,
 
     const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
     state += state_offset;
-    curr_state += state_offset;
+    curr_state += state_offset + col * S_v;
     attn_data += (sequence * n_tokens * H + h_idx) * S_v;
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
@@ -46,10 +47,11 @@ __global__ void gated_delta_net_cuda(const float * q,
     constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
     float         s_shard[rows_per_lane];
     // state is stored transposed: M[col][i] = S[i][col], row col is contiguous
+
 #pragma unroll
     for (int r = 0; r < rows_per_lane; r++) {
         const int i = r * warp_size + lane;
-        s_shard[r]  = curr_state[col * S_v + i];
+        s_shard[r]  = curr_state[i];
     }
 
     for (int t = 0; t < n_tokens; t++) {
@@ -63,6 +65,16 @@ __global__ void gated_delta_net_cuda(const float * q,
 
         const float beta_val = *beta_t;
 
+        // Cache k and q in registers
+        float k_reg[rows_per_lane];
+        float q_reg[rows_per_lane];
+#pragma unroll
+        for (int r = 0; r < rows_per_lane; r++) {
+            const int i = r * warp_size + lane;
+            k_reg[r] = k_t[i];
+            q_reg[r] = q_t[i];
+        }
+
         if constexpr (!KDA) {
             const float g_val = expf(*g_t);
 
@@ -70,8 +82,7 @@ __global__ void gated_delta_net_cuda(const float * q,
             float kv_shard = 0.0f;
 #pragma unroll
             for (int r = 0; r < rows_per_lane; r++) {
-                const int i = r * warp_size + lane;
-                kv_shard += s_shard[r] * k_t[i];
+                kv_shard += s_shard[r] * k_reg[r];
             }
             float kv_col = warp_reduce_sum<warp_size>(kv_shard);
 
@@ -83,9 +94,8 @@ __global__ void gated_delta_net_cuda(const float * q,
             float attn_partial = 0.0f;
 #pragma unroll
             for (int r = 0; r < rows_per_lane; r++) {
-                const int i = r * warp_size + lane;
-                s_shard[r]  = g_val * s_shard[r] + k_t[i] * delta_col;
-                attn_partial += s_shard[r] * q_t[i];
+                s_shard[r]  = g_val * s_shard[r] + k_reg[r] * delta_col;
+                attn_partial += s_shard[r] * q_reg[r];
             }
 
             float attn_col = warp_reduce_sum<warp_size>(attn_partial);
@@ -99,7 +109,7 @@ __global__ void gated_delta_net_cuda(const float * q,
 #pragma unroll
             for (int r = 0; r < rows_per_lane; r++) {
                 const int i = r * warp_size + lane;
-                kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
+                kv_shard += expf(g_t[i]) * s_shard[r] * k_reg[r];
             }
 
             float kv_col = warp_reduce_sum<warp_size>(kv_shard);
@@ -113,8 +123,8 @@ __global__ void gated_delta_net_cuda(const float * q,
 #pragma unroll
             for (int r = 0; r < rows_per_lane; r++) {
                 const int i = r * warp_size + lane;
-                s_shard[r]  = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
-                attn_partial += s_shard[r] * q_t[i];
+                s_shard[r]  = expf(g_t[i]) * s_shard[r] + k_reg[r] * delta_col;
+                attn_partial += s_shard[r] * q_reg[r];
             }
 
             float attn_col = warp_reduce_sum<warp_size>(attn_partial);