]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: larger SRAM reads for tile FA, AMD FP16 dot (llama/15927)
authorJohannes Gäßler <redacted>
Thu, 11 Sep 2025 19:19:58 +0000 (21:19 +0200)
committerGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
* CUDA: larger SRAM reads for tile FA, AMD FP16 dot

* fix logic for availability of v_dot2_f32_f16

src/ggml-cuda/common.cuh
src/ggml-cuda/fattn-tile.cu
src/ggml-cuda/vendors/hip.h

index 394595be0eada63bb91e5b5b84fd718485e9061c..b0feea362380b4fecd5e222237ce72486c30a0ea 100644 (file)
@@ -555,7 +555,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v
 }
 
 static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
-#if defined(GGML_USE_HIP) && defined(GCN)
+#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
     asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
 #else
 #ifdef FAST_FP16_AVAILABLE
@@ -567,7 +567,21 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
     acc += tmpv.x * tmpu.x;
     acc += tmpv.y * tmpu.y;
 #endif // FAST_FP16_AVAILABLE
-#endif // defined(GGML_USE_HIP) && defined(GCN)
+#endif // defined(GGML_USE_HIP) && (defined(RDNA2)  || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
+}
+
+// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
+template <int nbytes>
+static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
+    if constexpr (nbytes == 4) {
+        *(int *) dst = *(const int *) src;
+    } else if constexpr (nbytes == 8) {
+        *(int2 *) dst = *(const int2 *) src;
+    } else if constexpr (nbytes == 16) {
+        *(int4 *) dst = *(const int4 *) src;
+    } else {
+        static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
+    }
 }
 
 static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
index 64f7d4a1a14701851fac6f70c8f2aa5a5e4db53e..c6a399ce5d791ca397232ac88621e920985f287e 100644 (file)
@@ -8,11 +8,14 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
     if (GGML_CUDA_CC_IS_AMD(cc)) {
         switch (D) {
             case 64:
-                return ncols <= 16 ? 32 : 64;
+                return 64;
             case 128:
-                return ncols <= 16 ? 64 : warp_size;
             case 256:
-                return 64;
+                if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
+                    return ncols <= 16 ? 64 : 32;
+                } else {
+                    return 64;
+                }
             default:
                 GGML_ABORT("fatal error");
                 return -1;
@@ -41,17 +44,26 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
             GGML_ABORT("fatal error");
             return -1;
     }
+    GGML_UNUSED(warp_size);
 }
 
 static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
 #ifdef GGML_USE_HIP
     switch (D) {
         case 64:
-            return ncols <= 16 ? 32 : 64;
+            return 64;
         case 128:
-            return ncols <= 16 ? 64 : warp_size;
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 32;
+#else
+            return 64;
+#endif // defined(GCN) || defined(CDNA)
         case 256:
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 32;
+#else
             return 64;
+#endif // defined(GCN) || defined(CDNA)
         default:
             return -1;
     }
@@ -88,9 +100,17 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
         case 64:
             return 64;
         case 128:
-            return ncols <= 16 ? 2*warp_size : 128;
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 128;
+#else
+            return 64;
+#endif // defined(GCN) || defined(CDNA)
         case 256:
-            return ncols <= 16 ? 128 : 2*warp_size;
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 128;
+#else
+            return ncols <= 16 ? 64 : 256;
+#endif // defined(GCN) || defined(CDNA)
         default:
             return -1;
     }
@@ -196,14 +216,21 @@ static __global__ void flash_attn_tile(
 
     const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
+#if defined(GGML_USE_HIP)
+    constexpr int cpy_nb = 16;
+#else
+    constexpr int cpy_nb = 8;
+#endif // defined(GGML_USE_HIP) && defined(GCN)
+    constexpr int cpy_ne = cpy_nb / 4;
+
     __shared__ float KQ[ncols][kq_stride];
 #ifdef FAST_FP16_AVAILABLE
     __shared__ half2 Q_tmp[ncols][D/2];
-    __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + 1)]; // Padded to avoid memory bank conflicts.
+    __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
     half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #else
     __shared__ float Q_tmp[ncols][D];
-    __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + 1)]; // Padded to avoid memory bank conflicts.
+    __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
     float2 * KV_tmp_f2 = (float2 *) KV_tmp_f;
     float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #endif // FAST_FP16_AVAILABLE
@@ -256,11 +283,11 @@ static __global__ void flash_attn_tile(
                 for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) {
                     const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x];
 #ifdef FAST_FP16_AVAILABLE
-                    KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1 + threadIdx.x] = tmp_h2;
+                    KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x] = tmp_h2;
 #else
                     const float2 tmp_f2 = __half22float2(tmp_h2);
-                    KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1             + threadIdx.x] = tmp_f2.x;
-                    KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
+                    KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1             + threadIdx.x] = tmp_f2.x;
+                    KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
 #endif // FAST_FP16_AVAILABLE
                 }
             }
@@ -269,14 +296,14 @@ static __global__ void flash_attn_tile(
 
 #ifdef FAST_FP16_AVAILABLE
 #pragma unroll
-            for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; ++k_KQ_1) {
-                half2 K_k[kq_stride/warp_size];
-                half2 Q_k[ncols/nwarps];
+            for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
+                half2 K_k[kq_stride/warp_size][cpy_ne];
+                half2 Q_k[ncols/nwarps][cpy_ne];
 #else
 #pragma unroll
-            for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; ++k_KQ_1) {
-                float K_k[kq_stride/warp_size];
-                float Q_k[ncols/nwarps];
+            for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
+                float K_k[kq_stride/warp_size][cpy_ne];
+                float Q_k[ncols/nwarps][cpy_ne];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
@@ -284,9 +311,9 @@ static __global__ void flash_attn_tile(
                     const int i_KQ = i_KQ_0 + threadIdx.x;
 
 #ifdef FAST_FP16_AVAILABLE
-                    K_k[i_KQ_0/warp_size] = KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1];
+                    ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
 #else
-                    K_k[i_KQ_0/warp_size] = KV_tmp_f [i_KQ*(kq_nbatch   + 1) + k_KQ_1];
+                    ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_f [i_KQ*(kq_nbatch   + cpy_ne) + k_KQ_1]);
 #endif // FAST_FP16_AVAILABLE
                 }
 #pragma unroll
@@ -294,9 +321,9 @@ static __global__ void flash_attn_tile(
                     const int j_KQ = j_KQ_0 + threadIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
-                    Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1];
+                    ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
 #else
-                    Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0   + k_KQ_1];
+                    ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0   + k_KQ_1]);
 #endif // FAST_FP16_AVAILABLE
                 }
 
@@ -304,7 +331,10 @@ static __global__ void flash_attn_tile(
                 for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
 #pragma unroll
                     for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
-                        ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]);
+#pragma unroll
+                        for (int k = 0; k < cpy_ne; ++k) {
+                            ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0/nwarps][k]);
+                        }
                     }
                 }
             }
@@ -345,14 +375,54 @@ static __global__ void flash_attn_tile(
             kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
 
             float kqsum_add = 0.0f;
+            if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) {
 #pragma unroll
-            for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
-                const int i = i0 + threadIdx.x;
+                for (int i0 = 0; i0 < kq_stride; i0 += 4*warp_size) {
+                    const int i = i0 + 4*threadIdx.x;
 
-                const float diff = KQ[j][i] - kqmax[j0/nwarps];
-                const float val = expf(diff);
-                kqsum_add += val;
-                KQ[j][i] = val;
+                    float4 val = *(const float4 *) &KQ[j][i];
+                    val.x = expf(val.x - kqmax[j0/nwarps]);
+                    val.y = expf(val.y - kqmax[j0/nwarps]);
+                    val.z = expf(val.z - kqmax[j0/nwarps]);
+                    val.w = expf(val.w - kqmax[j0/nwarps]);
+                    kqsum_add += val.x + val.y + val.z + val.w;
+
+#ifdef FAST_FP16_AVAILABLE
+                    const half2 tmp[2] = {make_half2(val.x, val.y), make_half2(val.z, val.w)};
+                    ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
+#else
+                    ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
+#endif // FAST_FP16_AVAILABLE
+                }
+            } else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) {
+#pragma unroll
+                for (int i0 = 0; i0 < kq_stride; i0 += 2*warp_size) {
+                    const int i = i0 + 2*threadIdx.x;
+
+                    float2 val = *(const float2 *) &KQ[j][i];
+                    val.x = expf(val.x - kqmax[j0/nwarps]);
+                    val.y = expf(val.y - kqmax[j0/nwarps]);
+                    kqsum_add += val.x + val.y;
+#ifdef FAST_FP16_AVAILABLE
+                    const half2 tmp = make_half2(val.x, val.y);
+                    ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
+#else
+                    ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
+#endif // FAST_FP16_AVAILABLE
+                }
+            } else {
+                for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
+                    const int i = i0 + threadIdx.x;
+
+                    const float diff = KQ[j][i] - kqmax[j0/nwarps];
+                    const float val = expf(diff);
+                    kqsum_add += val;
+#ifdef FAST_FP16_AVAILABLE
+                    ((half *) KQ[j])[i] = val;
+#else
+                    KQ[j][i] = val;
+#endif // FAST_FP16_AVAILABLE
+                }
             }
             kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
 
@@ -419,8 +489,7 @@ static __global__ void flash_attn_tile(
                     const int j = j0 + threadIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
-                    const float tmp = KQ[j][k0 + k1];
-                    KQ_k[j0/nwarps] = make_half2(tmp, tmp);
+                    KQ_k[j0/nwarps] = __half2half2(((const half *)KQ[j])[k0 + k1]);
 #else
                     KQ_k[j0/nwarps] = KQ[j][k0 + k1];
 #endif // FAST_FP16_AVAILABLE
index c6a33d5de310f632523ab975c9060cfeb2b944ab..12bbee45566de393509d9c91400d3329d8161740 100644 (file)
 #define GCN
 #endif
 
+#if defined(__gfx900__) || defined(__gfx906__)
+#define GCN5
+#endif
+
+#if defined(__gfx803__)
+#define GCN4
+#endif
+
 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
 #define CDNA // For the entire family
 #endif