]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: generalize FP16 fattn vec kernel (llama/7061)
authorJohannes Gäßler <redacted>
Thu, 9 May 2024 12:32:02 +0000 (14:32 +0200)
committerGeorgi Gerganov <redacted>
Mon, 13 May 2024 08:02:26 +0000 (11:02 +0300)
* CUDA: generalize FP16 fattn vec kernel

* disable unsupported head sizes for AMD in test

* try AMD fix

* fix batch size 2-8

* partially revert changes

ggml-cuda/common.cuh
ggml-cuda/fattn.cu

index a4197f11ba779f52040e6932cb68e2689cc68bda..44e67e040e16a700924902aa884438f478ab63e2 100644 (file)
@@ -234,6 +234,97 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif //GGML_CUDA_F16
 
+#if defined(GGML_USE_HIPBLAS)
+#define __CUDA_ARCH__ 1300
+
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
+    defined(__gfx1150__) || defined(__gfx1151__)
+#define RDNA3
+#endif
+
+#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
+    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
+#define RDNA2
+#endif
+
+#ifndef __has_builtin
+    #define __has_builtin(x) 0
+#endif
+
+typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
+static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+#if __has_builtin(__builtin_elementwise_sub_sat)
+    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
+    return reinterpret_cast<const int &>(c);
+#else
+    int8x4_t c;
+    int16_t tmp;
+#pragma unroll
+    for (int i = 0; i < 4; i++) {
+        tmp = va[i] - vb[i];
+        if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
+        if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
+        c[i] = tmp;
+    }
+    return reinterpret_cast<int &>(c);
+#endif // __has_builtin(__builtin_elementwise_sub_sat)
+}
+
+static __device__ __forceinline__ int __vsub4(const int a, const int b) {
+    return __vsubss4(a, b);
+}
+
+static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
+    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
+    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
+    unsigned int c;
+    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
+#pragma unroll
+    for (int i = 0; i < 4; ++i) {
+        vc[i] = va[i] == vb[i] ? 0xff : 0x00;
+    }
+    return c;
+}
+
+static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
+#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
+    c = __builtin_amdgcn_sdot4(a, b, c, false);
+#elif defined(RDNA3)
+    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
+#elif defined(__gfx1010__) || defined(__gfx900__)
+    int tmp1;
+    int tmp2;
+    asm("\n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        "
+        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
+        : "v"(a), "v"(b)
+    );
+#else
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
+#endif
+    return c;
+}
+#endif // defined(GGML_USE_HIPBLAS)
+
+#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+
+#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+
+static bool fp16_mma_available(const int cc) {
+    return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
+}
+
 [[noreturn]]
 static __device__ void no_device_code(
     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
@@ -275,16 +366,28 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 }
 
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#if FP16_AVAILABLE
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #pragma unroll
-   for (int mask = 16; mask > 0; mask >>= 1) {
-       a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-   }
-   return a;
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
+        reinterpret_cast<half&>(a.x) +=  __low2half(a_other);
+        reinterpret_cast<half&>(a.y) += __high2half(a_other);
+    }
+    return a;
 #else
-   GGML_UNUSED(a);
-   NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+#else
+    NO_DEVICE_CODE;
+    return a;
+#endif // FP16_AVAILABLE
 }
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
@@ -296,20 +399,21 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
 }
 
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if FP16_AVAILABLE
 
-#if CUDART_VERSION >= CUDART_HMAX
-    return __hmax(a, b);
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+    return __float2half(fmaxf(__half2float(a), __half2float(b)));
 #else
-    return __half2float(a) > __half2float(b) ? a : b;
-#endif // CUDART_VERSION >= CUDART_HMAX
+    return __hmax(a, b);
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
 
 #else
-    GGML_UNUSED(a);
-    GGML_UNUSED(b);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+   NO_DEVICE_CODE;
+   GGML_UNUSED(b);
+   return a;
+#endif // FP16_AVAILABLE
 }
+
 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 
@@ -317,8 +421,8 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
     return __hmax2(a, b);
 #else
     half2 ret;
-    reinterpret_cast<half&>(ret.x) =  __low2float(a) >  __low2float(b) ?  __low2half(a) :  __low2half(b);
-    reinterpret_cast<half&>(ret.y) = __high2float(a) > __high2float(b) ? __high2half(a) : __high2half(b);
+    reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
+    reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
     return ret;
 #endif // CUDART_VERSION >= CUDART_HMAX
 
@@ -326,7 +430,7 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
     GGML_UNUSED(a);
     GGML_UNUSED(b);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 }
 
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
@@ -350,94 +454,6 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
 }
 #endif // CUDART_VERSION < 12000
 
-#if defined(GGML_USE_HIPBLAS)
-#define __CUDA_ARCH__ 1300
-
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
-    defined(__gfx1150__) || defined(__gfx1151__)
-#define RDNA3
-#endif
-
-#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
-    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
-#define RDNA2
-#endif
-
-#ifndef __has_builtin
-    #define __has_builtin(x) 0
-#endif
-
-typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
-typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
-static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
-    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
-    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
-#if __has_builtin(__builtin_elementwise_sub_sat)
-    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
-    return reinterpret_cast<const int &>(c);
-#else
-    int8x4_t c;
-    int16_t tmp;
-#pragma unroll
-    for (int i = 0; i < 4; i++) {
-        tmp = va[i] - vb[i];
-        if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
-        if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
-        c[i] = tmp;
-    }
-    return reinterpret_cast<int &>(c);
-#endif // __has_builtin(__builtin_elementwise_sub_sat)
-}
-
-static __device__ __forceinline__ int __vsub4(const int a, const int b) {
-    return __vsubss4(a, b);
-}
-
-static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
-    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
-    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
-    unsigned int c;
-    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
-#pragma unroll
-    for (int i = 0; i < 4; ++i) {
-        vc[i] = va[i] == vb[i] ? 0xff : 0x00;
-    }
-    return c;
-}
-
-static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
-#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
-    c = __builtin_amdgcn_sdot4(a, b, c, false);
-#elif defined(RDNA3)
-    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
-#elif defined(__gfx1010__) || defined(__gfx900__)
-    int tmp1;
-    int tmp2;
-    asm("\n \
-        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
-        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
-        v_add3_u32 %0, %1, %2, %0 \n \
-        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
-        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
-        v_add3_u32 %0, %1, %2, %0 \n \
-        "
-        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
-        : "v"(a), "v"(b)
-    );
-#else
-    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
-    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
-    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
-#endif
-    return c;
-}
-#endif // defined(GGML_USE_HIPBLAS)
-
-#define FP16_AVAILABLE     defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
-    defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
-
-#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
-
 // TODO: move to ggml-common.h
 static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
index c8a11d173346454f05f1d5cf5ed70627db9efeea..7c486f4829bdd7868a99d028bbea2bbc0f4c8b2f 100644 (file)
 #define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 #define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
 
-template<int D, int parallel_blocks> // D == head size
-__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
+template<int D, int ncols, int parallel_blocks> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -44,55 +46,77 @@ static __global__ void flash_attn_vec_ext_f16(
 #if FP16_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
-    const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on.
-    const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic);
+    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic0);
     const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
     const half   * V_h   = (const half   *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *)  mask + ne11*ic;
+    const half   * maskh = (const half   *)  mask + ne11*ic0;
 
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
+    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+    constexpr int nwarps = D / WARP_SIZE;
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
-    __builtin_assume(tid < nwarps*WARP_SIZE);
+    __builtin_assume(tid < D);
 
-    __shared__ half KQ[nwarps*WARP_SIZE];
-    KQ[tid] = -INFINITY;
+    __shared__ half KQ[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        KQ[j*D + tid] = -HALF_MAX_HALF;
+    }
     half2 * KQ2 = (half2 *) KQ;
 
-    half kqmax = -HALF_MAX_HALF;
-    half kqsum = 0.0f;
+    half kqmax[ncols];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqmax[j] = -HALF_MAX_HALF;
+    }
+    half kqsum[ncols] = {0.0f};
 
-    __shared__ half kqmax_shared[WARP_SIZE];
-    __shared__ half kqsum_shared[WARP_SIZE];
-    if (threadIdx.y == 0) {
-        kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
-        kqsum_shared[threadIdx.x] = 0.0f;
+    __shared__ half kqmax_shared[ncols][WARP_SIZE];
+    __shared__ half kqsum_shared[ncols][WARP_SIZE];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (threadIdx.y == 0) {
+            kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
+            kqsum_shared[j][threadIdx.x] = 0.0f;
+        }
     }
     __syncthreads();
 
     // Convert Q to half2 and store in registers:
-    half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE];
+    half2 Q_h2[ncols][D/(2*WARP_SIZE)];
 #pragma unroll
-    for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-        const int i = i0 + threadIdx.x;
-        if (i0 + WARP_SIZE > D/2 && i >= D/2) {
-            break;
-        }
+    for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
 
-        Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y);
+            const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
+            Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+        }
     }
 
-    half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value.
+    half2 VKQ[ncols] = {{0.0f, 0.0f}};
 
-    const int k_start  = parallel_blocks == 1 ? 0 : ip*D;
+    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
     for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
-        half kqmax_new = kqmax;
+
+        // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
+        // see https://github.com/ggerganov/llama.cpp/pull/7061 .
+        // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
+        half kqmax_new = kqmax[0];
+        half kqmax_new_arr[ncols];
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            kqmax_new_arr[j] = kqmax[j];
+        }
+
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
             const int i_KQ = i_KQ_0 + threadIdx.y;
@@ -101,89 +125,112 @@ static __global__ void flash_attn_vec_ext_f16(
                 break;
             }
 
-            half2 sum2 = make_half2(0.0f, 0.0f);
+            half2 sum2[ncols] = {{0.0f, 0.0f}};
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
                 const int k_KQ = k_KQ_0 + threadIdx.x;
-                if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) {
-                    break;
-                }
 
                 const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
-                sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+#pragma unroll
+                for (int j = 0; j < ncols; ++j) {
+                    sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE];
+                }
             }
 
-            sum2 = warp_reduce_sum(sum2);
-            half sum = __low2half(sum2) + __high2half(sum2);
-            sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f);
-            kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
-            if (threadIdx.x == 0) {
-                KQ[i_KQ] = sum;
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                sum2[j] = warp_reduce_sum(sum2[j]);
+                half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
+                sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+
+                if (ncols == 1) {
+                    kqmax_new        = ggml_cuda_hmax(kqmax_new,        sum);
+                } else {
+                    kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
+                }
+
+                if (threadIdx.x == 0) {
+                    KQ[j*D + i_KQ] = sum;
+                }
             }
         }
 
-        kqmax_new = warp_reduce_max(kqmax_new);
-        if (threadIdx.x == 0) {
-            kqmax_shared[threadIdx.y] = kqmax_new;
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
+
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
+            }
         }
+
         __syncthreads();
-        kqmax_new = kqmax_shared[threadIdx.x];
-        kqmax_new = warp_reduce_max(kqmax_new);
 
-        const half KQ_max_scale = hexp(kqmax - kqmax_new);
-        kqmax = kqmax_new;
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
 
-        const half val = hexp(KQ[tid] - kqmax);
-        kqsum = kqsum*KQ_max_scale + val;
-        KQ[tid] = val;
+            const half val = hexp(KQ[j*D + tid] - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale + val;
+            KQ[j*D + tid] = val;
 
-        VKQ *= __half2half2(KQ_max_scale);
+            VKQ[j] *= __half2half2(KQ_max_scale);
+        }
 
         __syncthreads();
 
-        if (tid < D) {
 #pragma unroll
-            for (int k0 = 0; k0 < D; k0 += 2) {
-                if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
-                    break;
-                }
+        for (int k0 = 0; k0 < D; k0 += 2) {
+            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
+                break;
+            }
 
-                half2 V_k;
-                reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
-                reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
-                VKQ += V_k*KQ2[k0/2];
+            half2 V_k;
+            reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
+            reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
             }
         }
 
         __syncthreads();
     }
 
-    if (tid >= D) {
-        kqsum = 0.0f;
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqsum[j] = warp_reduce_sum(kqsum[j]);
+        if (threadIdx.x == 0) {
+            kqsum_shared[j][threadIdx.y] = kqsum[j];
+        }
     }
 
-    kqsum = warp_reduce_sum(kqsum);
-    if (threadIdx.x == 0) {
-        kqsum_shared[threadIdx.y] = kqsum;
-    }
     __syncthreads();
-    kqsum = kqsum_shared[threadIdx.x];
-    kqsum = warp_reduce_sum(kqsum);
 
-    if (tid >= D) {
-        return;
-    }
+#pragma unroll
+    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
+        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
 
-    half dst_val = (__low2half(VKQ) + __high2half(VKQ));
-    if (parallel_blocks == 1) {
-        dst_val /= kqsum;
+        half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
+        if (parallel_blocks == 1) {
+            dst_val /= kqsum[j_VKQ];
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
-    dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val;
 
-    if (parallel_blocks == 1 || tid != 0) {
-        return;
+    if (parallel_blocks != 1 && tid != 0) {
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
+        }
     }
-    dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
 #else
    NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
@@ -191,7 +238,9 @@ static __global__ void flash_attn_vec_ext_f16(
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -573,7 +622,9 @@ static __global__ void flash_attn_ext_f16(
 }
 
 template<int D, int parallel_blocks> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_combine_results(
         const float  * __restrict__ VKQ_parts,
         const float2 * __restrict__ VKQ_meta,
@@ -642,7 +693,7 @@ static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
 
-template <int D, int parallel_blocks> void launch_fattn_vec_f16(
+template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f16(
         const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
         ggml_cuda_pool & pool, cudaStream_t main_stream
 ) {
@@ -656,13 +707,13 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
 
     constexpr int  nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
     const     dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const     dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
+    const     dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
     const     int  shmem = 0;
 
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));
 
-    flash_attn_vec_ext_f16<D, parallel_blocks>
+    flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
         <<<blocks_num, block_dim, shmem, main_stream>>> (
                 (const char *) Q->data,
                 (const char *) K->data,
@@ -783,10 +834,99 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     ggml_cuda_set_device(ctx.device);
 
+    const int cc  = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
     const int32_t precision = KQV->op_params[1];
 
+    if (!fp16_mma_available(cc)) {
+        GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+        GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+
+        if (Q->ne[1] == 1) {
+            constexpr int cols_per_block = 1;
+            constexpr int parallel_blocks = 4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+            return;
+        }
+
+        if (Q->ne[1] == 2) {
+            constexpr int cols_per_block = 2;
+            constexpr int parallel_blocks = 4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+            return;
+        }
+
+        if (Q->ne[1] <= 4) {
+            constexpr int cols_per_block = 4;
+            constexpr int parallel_blocks = 4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+            return;
+        }
+
+        if (Q->ne[1] <= 8) {
+            constexpr int cols_per_block = 8;
+            constexpr int parallel_blocks = 4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+            return;
+        }
+
+        constexpr int cols_per_block = 8;
+        constexpr int parallel_blocks = 1;
+        switch (Q->ne[0]) {
+            case 64:
+                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                break;
+            case 128:
+                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        }
+        return;
+    }
+
     if (precision != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;
@@ -845,16 +985,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+        constexpr int cols_per_block = 1;
         constexpr int parallel_blocks = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                 break;
             case 128:
-                launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                 break;
             case 256:
-                launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                 break;
             default:
                 GGML_ASSERT(false);