]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml-cuda: native bf16 flash attention for vec kernel (llama/20525)
authorPatrick Buckley <redacted>
Sun, 22 Mar 2026 10:05:51 +0000 (03:05 -0700)
committerGeorgi Gerganov <redacted>
Sat, 28 Mar 2026 11:39:09 +0000 (13:39 +0200)
* ggml-cuda: native bf16 flash attention for vec and tile kernels

mma kernel still converts bf16 to fp16 before launch, native mma bf16 todo

* ggml-cuda: address code owner review feedback

reverted tile kernel changes to avoid larger refactor

* fix ci failures on turing and hip

* fix bf16 vec kernel compile on hip v_dot2 platforms

* add comments

---------

Co-authored-by: Johannes Gäßler <redacted>
21 files changed:
src/ggml-cuda/CMakeLists.txt
src/ggml-cuda/convert.cuh
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/fattn-vec.cuh
src/ggml-cuda/fattn.cu
src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/generate_cu_files.py
src/ggml-hip/CMakeLists.txt
src/ggml-musa/CMakeLists.txt

index 262f88204e01cf07a8d8c6bab23efb2b46435dd3..419862101d112a96efb3806f689c35390e4f4e0c 100644 (file)
@@ -116,12 +116,11 @@ if (CUDAToolkit_FOUND)
         list(APPEND GGML_SOURCES_CUDA ${SRCS})
         add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
     else()
-        file(GLOB   SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB   SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB   SRCS "template-instances/fattn-vec*f16-f16.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        list(APPEND GGML_SOURCES_CUDA
+            template-instances/fattn-vec-instance-f16-f16.cu
+            template-instances/fattn-vec-instance-q4_0-q4_0.cu
+            template-instances/fattn-vec-instance-q8_0-q8_0.cu
+            template-instances/fattn-vec-instance-bf16-bf16.cu)
     endif()
 
     ggml_add_backend_library(ggml-cuda
index 09f9a33f909187b5b058177e1b0c73ff5ccd2c6d..b8caeacf094286b5e841c62b92fc5892c216c4b4 100644 (file)
@@ -41,6 +41,12 @@ template<typename dst_t, typename src_t>
         return __bfloat162float(x);
     } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
         return __float22half2_rn(x);
+    } else if constexpr(std::is_same_v<src_t, nv_bfloat162> && std::is_same_v<dst_t, float2>) {
+#if !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+        return __bfloat1622float2(x);
+#else
+        return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x)));
+#endif // !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
         // bypass compile error on cuda 12.0.1
 #ifdef GGML_USE_HIP
index e9abdf288c43f42db533b9457e1786342ecafb07..c59a4db39993b7c1ab7abf812c2a23136fa1b134 100644 (file)
@@ -74,6 +74,37 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
     return sum;
 }
 
+template <int D, int nthreads>
+static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
+
+    const nv_bfloat162 * K_bf16 = (const nv_bfloat162 *) K_c;
+    GGML_UNUSED(Q_q8);
+    GGML_UNUSED(Q_ds_v);
+
+    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
+        __align__(16) nv_bfloat162 tmp[cpy_ne];
+        ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
+#pragma unroll
+        for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
+#ifdef V_DOT2_F32_F16_AVAILABLE
+            // FIXME replace macros in vector FA kernel with templating and use FP32 for BF16
+            ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), __half22float2(((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]));
+#else
+            ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
+#endif // V_DOT2_F32_F16_AVAILABLE
+        }
+    }
+
+    return sum;
+}
+
 template<int D, int nthreads>
 static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -321,6 +352,19 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_
     }
 }
 
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    static_assert(std::is_same_v<T, float>, "BF16 V dequantization only supports float output");
+    static_assert(ne % 2 == 0, "bad ne");
+    __align__(16) nv_bfloat162 tmp[ne/2];
+    ggml_cuda_memcpy_1<ne*sizeof(nv_bfloat16)>(tmp, (const nv_bfloat16 *) vx + i0);
+    float2 * dst_f2 = (float2 *) dst;
+#pragma unroll
+    for (int l = 0; l < ne/2; ++l) {
+        dst_f2[l] = ggml_cuda_cast<float2>(tmp[l]);
+    }
+}
+
 template <typename T, int ne>
 static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
     const block_q4_0 * x = (const block_q4_0 *) vx;
@@ -547,6 +591,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
         return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
     } else if constexpr (type_K == GGML_TYPE_Q8_0) {
         return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
+    } else if constexpr (type_K == GGML_TYPE_BF16) {
+        return vec_dot_fattn_vec_KQ_bf16<D, nthreads>;
     } else {
         static_assert(type_K == -1, "bad type");
         return nullptr;
@@ -567,6 +613,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
         return dequantize_V_q5_1<T, ne>;
     } else if constexpr (type_V == GGML_TYPE_Q8_0) {
         return dequantize_V_q8_0<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_BF16) {
+        return dequantize_V_bf16<float, ne>;
     } else {
         static_assert(type_V == -1, "bad type");
         return nullptr;
index 7cbe32633e5877cd1c68a83fde5719950c23363f..f0bd42a5761232952df3e6e529ec607a718bda54 100644 (file)
@@ -75,17 +75,17 @@ static __global__ void flash_attn_ext_vec(
 #endif // GGML_USE_HIP
 
     constexpr int nthreads    = ggml_cuda_fattn_vec_get_nthreads_device();
-    constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
-    constexpr int nthreads_V  = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
+    constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q;
+    constexpr int nthreads_V  = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q;
 
     static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
     static_assert(WARP_SIZE % nthreads_V  == 0, "bad nthreads_V");
 
-    constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
+    constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4;
     constexpr int V_cols_per_iter   = WARP_SIZE / nthreads_V;
 
     constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
-    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
+    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16;
 #ifdef V_DOT2_F32_F16_AVAILABLE
     constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half,  V_rows_per_thread>();
 #else
@@ -323,8 +323,18 @@ static __global__ void flash_attn_ext_vec(
 #pragma unroll
             for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
                 half2 tmp[V_rows_per_thread/2];
-                dequantize_V(V + k*nb21, tmp,
-                    2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
+                if constexpr (type_V == GGML_TYPE_BF16) {
+                    float2 tmp_f[V_rows_per_thread/2];
+                    dequantize_V(V + k*nb21, tmp_f,
+                        2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
+#pragma unroll
+                    for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
+                        tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]);
+                    }
+                } else {
+                    dequantize_V(V + k*nb21, tmp,
+                        2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
+                }
 #pragma unroll
                 for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
 #pragma unroll
@@ -563,6 +573,7 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten
     extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
     extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
     extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \
 
 EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
 EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
@@ -570,6 +581,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
 EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
 EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
 EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16)
 
 EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
 EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
@@ -577,6 +589,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
 EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
 EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
 EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16)
 
 EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
 EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
@@ -584,3 +597,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
 EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
 EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
 EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16)
index 85c177f496fad2bb6ecbee64f06fbbc8ea759027..a25a890db6dafca57f7b9fef0d119ee40400b295 100644 (file)
@@ -224,6 +224,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16)
 
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
@@ -231,6 +232,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_0)
 
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_1)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
@@ -238,6 +240,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_1)
 
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
@@ -245,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_0)
 
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_1)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
@@ -252,6 +256,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_1)
 
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q8_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
@@ -259,10 +264,20 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q8_0)
+
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_BF16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_BF16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_BF16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_BF16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
 #else
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
     FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
 #endif // GGML_CUDA_FA_ALL_QUANTS
 
     GGML_ABORT("fatal error");
@@ -355,6 +370,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 #endif // GGML_CUDA_FA_ALL_QUANTS
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_BF16:
             break;
         default:
             return BEST_FATTN_KERNEL_NONE;
diff --git a/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu b/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu
new file mode 100644 (file)
index 0000000..3a2fa99
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16);
diff --git a/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu b/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu
new file mode 100644 (file)
index 0000000..60f0f6f
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_F16);
diff --git a/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu b/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu
new file mode 100644 (file)
index 0000000..489e05f
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_0);
diff --git a/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu b/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu
new file mode 100644 (file)
index 0000000..6fa3c26
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_1);
diff --git a/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu b/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu
new file mode 100644 (file)
index 0000000..421027f
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_0);
diff --git a/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu b/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu
new file mode 100644 (file)
index 0000000..abbc943
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_1);
diff --git a/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu b/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu
new file mode 100644 (file)
index 0000000..d641f85
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q8_0);
diff --git a/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu b/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu
new file mode 100644 (file)
index 0000000..d1071dc
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_BF16);
diff --git a/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu b/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu
new file mode 100644 (file)
index 0000000..8afda31
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_BF16);
diff --git a/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu b/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu
new file mode 100644 (file)
index 0000000..506864a
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_BF16);
diff --git a/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu b/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu
new file mode 100644 (file)
index 0000000..0bbda83
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_BF16);
diff --git a/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu b/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu
new file mode 100644 (file)
index 0000000..79be24d
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_BF16);
diff --git a/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu b/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu
new file mode 100644 (file)
index 0000000..45636e5
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16);
index e382df1ae206472917b63f08c72c52a5ac691b27..3b5ab12fc408648c0a9363382ebb94be87d95b79 100755 (executable)
@@ -5,7 +5,7 @@ import os
 
 HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
 
-TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
+TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
 
 SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
index f96c6e09a9bf57113586123b552b28412ce80dee..291b4837455f4645f73ce6482eb1eea351c2b6b2 100644 (file)
@@ -71,12 +71,11 @@ if (GGML_CUDA_FA_ALL_QUANTS)
     list(APPEND GGML_SOURCES_ROCM ${SRCS})
     add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
 else()
-    file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
-    list(APPEND GGML_SOURCES_ROCM ${SRCS})
-    file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
-    list(APPEND GGML_SOURCES_ROCM ${SRCS})
-    file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
-    list(APPEND GGML_SOURCES_ROCM ${SRCS})
+    list(APPEND GGML_SOURCES_ROCM
+        ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
+        ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
+        ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
+        ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu)
 endif()
 
 ggml_add_backend_library(ggml-hip
index d76cb51977f9000cda679fef04b8d8e072a2bd09..cc53c812ce5f47a2b955d8cc9af4fdcb4e62bbd7 100644 (file)
@@ -48,12 +48,11 @@ if (MUSAToolkit_FOUND)
         list(APPEND GGML_SOURCES_MUSA ${SRCS})
         add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
     else()
-        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
-        list(APPEND GGML_SOURCES_MUSA ${SRCS})
-        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
-        list(APPEND GGML_SOURCES_MUSA ${SRCS})
-        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
-        list(APPEND GGML_SOURCES_MUSA ${SRCS})
+        list(APPEND GGML_SOURCES_MUSA
+            ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
+            ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
+            ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
+            ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu)
     endif()
 
     set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)