]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sycl : enhance fattn perf (llama/21185)
authorNeo Zhang <redacted>
Tue, 31 Mar 2026 10:31:50 +0000 (18:31 +0800)
committerGeorgi Gerganov <redacted>
Wed, 1 Apr 2026 13:00:26 +0000 (16:00 +0300)
src/ggml-sycl/fattn-tile.hpp

index 29fd0f8c9ece520cb64db5039774d5f836132650..c4d24613a5577dd9373637b144b1574a1c114ae7 100644 (file)
@@ -70,6 +70,7 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co
     GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
     GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
     GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2,  64,  64)
 
     return 0;
 }
@@ -310,11 +311,11 @@ static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const
                                                       sycl::half2 * const __restrict__ tile_KV,
                                                       const int stride_KV,
                                                       const int i_sup) {
+    auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
     constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
     constexpr int cpy_ne = cpy_nb / 4;
 
     auto load = [&] (const int n) {
-        auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
         const int stride_j = warp_size >> n;
 
         if (stride_j == 0) {
@@ -455,7 +456,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
 
     flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
         (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
-    item_ct1.barrier();
+    item_ct1.barrier(sycl::access::fence_space::local_space);
 
 #ifdef SYCL_FAST_FP16
     static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
@@ -505,7 +506,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
     }
 
     if (k_KQ_0 + nbatch_K < DKQ) {
-        item_ct1.barrier();  // Sync not needed on last iteration.
+        item_ct1.barrier(sycl::access::fence_space::local_space);  // Sync not needed on last iteration.
     }
 }
 
@@ -545,7 +546,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
                                                  const int         k_VKQ_max,
                                                  const int         col_Q_0,
                                                  float *           KQ_max_new_shared) {
-    auto          item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
     constexpr int cpy_nb   = ggml_sycl_get_max_cpy_bytes();
     constexpr int cpy_ne = cpy_nb / 4;
 
@@ -620,14 +621,14 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
     }
 
     if constexpr (np == 1) {
-        item_ct1.barrier();
+        item_ct1.barrier(sycl::access::fence_space::local_space);
     } else {
         static_assert(cpw == 1, "bad cpw");
 
         if (item_ct1.get_local_id(2) == 0) {
             KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0];
         }
-        item_ct1.barrier();
+        item_ct1.barrier(sycl::access::fence_space::local_space);
         KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np];
         KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
     }
@@ -697,7 +698,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
     for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
         flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
             (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
-        item_ct1.barrier();
+        item_ct1.barrier(sycl::access::fence_space::local_space);
 
 #ifdef SYCL_FAST_FP16
 #pragma unroll
@@ -765,7 +766,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
             }
         }
 #endif // SYCL_FAST_FP16
-        item_ct1.barrier();
+        item_ct1.barrier(sycl::access::fence_space::local_space);
     }
 }
 
@@ -972,7 +973,7 @@ static void flash_attn_tile(const char *  Q,
         }
     }
 
-    item_ct1.barrier();
+    item_ct1.barrier(sycl::access::fence_space::local_space);
 
     // Main loop over KV cache:
     const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
@@ -1051,7 +1052,7 @@ static void flash_attn_tile(const char *  Q,
             return;
         }
 
-        item_ct1.barrier();
+        item_ct1.barrier(sycl::access::fence_space::local_space);
 
 #pragma unroll
         for (int ip = 1; ip < np; ++ip) {
@@ -1193,37 +1194,39 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm
 
     constexpr size_t nbytes_shared = 0;
 
-    if constexpr (DV <= 256) {
-        if (Q->ne[1] > 16/ncols2) {
-            constexpr int cols_per_block = 32;
-            const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
-            const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
-            launch_fattn<DV, cols_per_block/ncols2, ncols2,
-                flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
-                (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
-            return;
+    if (DV < 512 && Q->ne[1] < 32) {
+        if constexpr (ncols2 <= 32) {
+            if (Q->ne[1] > 16/ncols2) {
+                constexpr int cols_per_block = 32;
+                const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+                const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+                launch_fattn<DV, cols_per_block/ncols2, ncols2,
+                    flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+                    (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+                return;
+            }
         }
-    }
-
-    if (Q->ne[1] > 8/ncols2) {
-        constexpr int cols_per_block = 16;
-        const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
-        const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
-        launch_fattn<DV, cols_per_block/ncols2, ncols2,
-            flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
-            (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
-        return;
-    }
-
-    if constexpr (ncols2 <= 8) {
-        if (Q->ne[1] > 4/ncols2) {
-            constexpr int cols_per_block = 8;
-            const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
-            const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
-            launch_fattn<DV, cols_per_block/ncols2, ncols2,
-                flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
-                (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
-            return;
+        if constexpr (ncols2 <= 16) {
+            if (Q->ne[1] > 8/ncols2) {
+                constexpr int cols_per_block = 16;
+                const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+                const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+                launch_fattn<DV, cols_per_block/ncols2, ncols2,
+                    flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+                    (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+                return;
+            }
+        }
+        if constexpr (ncols2 <= 8) {
+            if (Q->ne[1] > 4/ncols2) {
+                constexpr int cols_per_block = 8;
+                const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+                const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+                launch_fattn<DV, cols_per_block/ncols2, ncols2,
+                    flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+                    (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+                return;
+            }
         }
     }