]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: generalized (mma) FA, add Volta support (#17505)
authorJohannes Gäßler <redacted>
Wed, 3 Dec 2025 15:57:05 +0000 (16:57 +0100)
committerGitHub <redacted>
Wed, 3 Dec 2025 15:57:05 +0000 (16:57 +0100)
* CUDA: generalized (mma) FA, add Volta support

* use struct for MMA FA kernel config

---------

Co-authored-by: Aman Gupta <aman>
ggml/include/ggml.h
ggml/src/ggml-cuda/fattn-common.cuh
ggml/src/ggml-cuda/fattn-mma-f16.cuh
ggml/src/ggml-cuda/fattn-tile.cuh
ggml/src/ggml-cuda/fattn-vec.cuh
ggml/src/ggml-cuda/fattn-wmma-f16.cu
ggml/src/ggml-cuda/fattn-wmma-f16.cuh
ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-cuda/mma.cuh
ggml/src/ggml-cuda/mmf.cuh

index 48da68fe7e3eeaf8d4411847cce245b26a6f85c5..e665614670d4802a8c1ead1a62a4aa12e7dd87a8 100644 (file)
@@ -2279,7 +2279,7 @@ extern "C" {
             float                 stop,
             float                 step);
 
-#define GGML_KQ_MASK_PAD 64
+#define GGML_KQ_MASK_PAD 1
 
     // q:    [n_embd_k, n_batch,     n_head,    ne3 ]
     // k:    [n_embd_k, n_kv,        n_head_kv, ne3 ]
index 5cdd4bb211492132dfbcdf7e6ffaa05b77ca9a25..02443b8c638294ee143e8c6bd8cc354be78a30df 100644 (file)
@@ -25,7 +25,7 @@ typedef void (* fattn_kernel_t)(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
                             const int32_t nb01, const int32_t nb02, const int32_t nb03,
         const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
                             const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -621,7 +621,8 @@ static __global__ void flash_attn_mask_to_KV_max(
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
+        const int nbatch_fa) {
     constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
@@ -632,8 +633,8 @@ static __global__ void flash_attn_stream_k_fixup(
 
     const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
 
-    const int iter_k = ne11 / FATTN_KQ_STRIDE;
-    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+    const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j = (ne01 + (ncols1    - 1)) / ncols1;
 
     const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
     const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
@@ -765,7 +766,7 @@ static __global__ void flash_attn_combine_results(
 template <int DV, int ncols1, int ncols2>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
-    const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
+    const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
@@ -790,8 +791,6 @@ void launch_fattn(
     GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
-    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
-        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
@@ -915,7 +914,7 @@ void launch_fattn(
 
         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
     } else {
-        const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+        const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
 
         // parallel_blocks must not be larger than what the tensor size allows:
         parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
@@ -970,6 +969,9 @@ void launch_fattn(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    // TODO other tensor dimensions after removal of WMMA kernel:
+    const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
+
     GGML_ASSERT(block_dim.x % warp_size == 0);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
@@ -980,7 +982,7 @@ void launch_fattn(
         KV_max.ptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
+        Q->ne[0], ne01,     Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
         nb21, nb22, nb23,
         mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
@@ -995,7 +997,7 @@ void launch_fattn(
 
             flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
index 57defb0c629d6984f2066d29d00b3a850a06c2da..b6250cf7949d160e1b5bc0de6a915dbb722ba7e2 100644 (file)
 
 using namespace ggml_cuda_mma;
 
-typedef tile<16,  8, half2> tile_A;
-typedef tile< 8,  8, half2> tile_B;
-typedef tile<16,  8, half2> tile_B_16;
-typedef tile<16,  8, float> tile_C_KQ;
-typedef tile<16, 16, float> tile_C_KQ_16;
-typedef tile<16,  4, half2> tile_C_VKQ;
-typedef tile<16,  8, half2> tile_C_VKQ_16;
-
-// Config options for specific head sizes.
+// Config options for the MMA kernel.
 // Should not affect results, only speed/register pressure/shared memory use.
-//
-// nbatch_fa:      number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
-// nwarps_max:     maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory).
-// Q_in_reg:       whether the Q values should be kept permanently in registers.
-// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading.
-// nbatch_K2:      number of K half2 values in direction of DKQ to load in parallel.
-// nbatch_V2:      number of V half2 values in direction of DV to load in parallel.
-// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel.
-
-template <int DKQ, int DV>
-struct fattn_mma_f16_config;
-
-template <>
-struct fattn_mma_f16_config< 64,  64> {
-    static constexpr int  nbatch_fa      = 64;
-    static constexpr int  nwarps_max     = 4;
-    static constexpr bool Q_in_reg       = true;
-    static constexpr int  nstages_target = 2;
-
-    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
-        return 32;
-    }
-
-    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
-        return 32;
-    }
-
-    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
-        return 32;
-    }
-
-    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
-        return 32;
-    }
-
-    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
-        return 32;
-    }
-
-    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
-        return 32;
-    }
+struct fattn_mma_config {
+    int  nthreads;       // Number of threads per CUDA block.
+    int  occupancy;      // Targeted occupancy for the MMA kernel.
+    int  nbatch_fa;      // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
+    int  nbatch_K2;      // Number of K half2 values in direction of DKQ to load in parallel.
+    int  nbatch_V2;      // Number of V half2 values in direction of DV to load in parallel.
+    int  nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel.
+    int  nstages_target; // Number of pipeline stages to use ideally, 1 == always load data synchronously, 2 == preload data if there is hardware support.
+    bool Q_in_reg;       // Whether the Q values should be kept permanently in registers.
+
+    constexpr __host__ __device__ fattn_mma_config(
+            int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) :
+        nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine),
+        nstages_target(nstages_target), Q_in_reg(Q_in_reg) {}
 };
 
-template <>
-struct fattn_mma_f16_config< 80,  80> {
-    static constexpr int  nbatch_fa      = 64;
-    static constexpr int  nwarps_max     = 4;
-    static constexpr bool Q_in_reg       = true;
-    static constexpr int  nstages_target = 2;
-
-    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
-        return 40;
-    }
-
-    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
-        return 40;
-    }
-
-    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
-        return 40;
-    }
-
-    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
-        return 40;
-    }
-
-    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
-        return 40;
-    }
-
-    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
-        return 40;
-    }
-};
-
-template <>
-struct fattn_mma_f16_config< 96,  96> {
-    static constexpr int  nbatch_fa      = 64;
-    static constexpr int  nwarps_max     = 4;
-    static constexpr bool Q_in_reg       = true;
-    static constexpr int  nstages_target = 2;
-
-    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
-        return 48;
-    }
-
-    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
-        return 48;
-    }
-
-    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
-        return 48;
-    }
-
-    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
-        return 48;
-    }
-
-    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
-        return 48;
-    }
+#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \
+    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) {                                                                                                       \
+        static_assert((nthreads_)       % 32 == 0 && (nthreads_)       <= 512, "bad nthreads");                                                                    \
+        static_assert(                               (occupancy_)      <=   8, "bad occupancy");                                                                   \
+        static_assert((nbatch_fa_)      % 32 == 0 && (nbatch_fa_)      <= 256, "bad nbatch_fa");                                                                   \
+        static_assert((nbatch_K2_)      %  4 == 0 && (nbatch_K2_)      <= 512, "bad nbatch_K2");                                                                   \
+        static_assert((nbatch_V2_)      %  4 == 0 && (nbatch_V2_)      <= 256, "bad nbatch_V2");                                                                   \
+        static_assert((nbatch_combine_) %  4 == 0 && (nbatch_combine_) <= 128, "bad nbatch_combine");                                                              \
+        static_assert((nstages_target_)      >= 1 && (nstages_target_) <=   2, "bad nstages_target");                                                              \
+        return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)};           \
+    }                                                                                                                                                              \
+
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64,  8, 128, 2, 128,  32,  32,  32, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 16, 128, 2,  64,  32,  32,  32, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 32, 128, 2,  64,  32,  32,  32, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 64, 128, 2,  64,  32,  32,  32, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80,  8, 128, 2, 128,  40,  40,  40, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 16, 128, 2,  64,  40,  40,  40, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 32, 128, 2,  64,  40,  40,  40, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 64, 128, 2,  64,  40,  40,  40, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96,  8, 128, 2, 128,  48,  48,  48, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 16, 128, 2,  64,  48,  48,  48, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 32, 128, 2,  64,  48,  48,  48, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 64, 128, 2,  64,  48,  48,  48, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112,  8, 128, 2, 128,  56,  56,  56, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2,  64,  56,  56,  56, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2,  64,  56,  56,  56, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2,  64,  56,  56,  56, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128,  8, 128, 2, 128,  64,  64,  64, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2,  64,  64,  64,  64, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2,  64,  64,  64,  64, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2,  64,  64,  64,  64, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256,  8,  64, 4,  64, 128, 128, 128, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16,  64, 4,  32, 128, 128, 128, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  32, 128, 128, 128, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  32, 128, 128, 128, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32, 288, 256, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);
+
+    return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
+}
 
-    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
-        return 48;
-    }
-};
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256,  8, 128, 2,  64, 128, 128, 128, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2,  64, 128, 128, 128, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  64, 128, 128,  64, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  64, 128, 128,  64, 2, true);
 
-template <>
-struct fattn_mma_f16_config<112, 112> {
-    static constexpr int  nbatch_fa      = 64;
-    static constexpr int  nwarps_max     = 4;
-    static constexpr bool Q_in_reg       = true;
-    static constexpr int  nstages_target = 2;
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32,  96,  64, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32,  96,  64, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);
 
-    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
-        return 56;
-    }
+    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
 
-    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
-        return 56;
-    }
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256,  64, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32, 288, 256,  64, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128,  64, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128,  64, 1, false);
 
-    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
-        return 56;
-    }
+    // TODO tune specifically for Volta
+    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
 
-    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
-        return 56;
+static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
+    if (ampere_mma_available(cc)) {
+        return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
     }
-
-    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
-        return 56;
+    if (turing_mma_available(cc)) {
+        return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
     }
+    GGML_ASSERT(volta_mma_available(cc));
+    return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
+}
 
-    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
-        return 56;
-    }
-};
+static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) {
+#if defined(AMPERE_MMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+#elif defined(TURING_MMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
+#elif defined(VOLTA_MMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
+#else
+    GGML_UNUSED_VARS(DKQ, DV, ncols);
+    return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
+#endif // defined(AMPERE_MMA_AVAILABLE)
+}
 
-template <>
-struct fattn_mma_f16_config<128, 128> {
-    static constexpr int  nbatch_fa      = 64;
-    static constexpr int  nwarps_max     = 4;
-    static constexpr bool Q_in_reg       = true;
-    static constexpr int  nstages_target = 2;
+static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads;
+}
 
-    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
-        return 64;
-    }
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads;
+}
 
-    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
-        return 64;
-    }
+static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy;
+}
 
-    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
-        return 64;
-    }
+static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy;
+}
 
-    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
-        return 64;
-    }
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa;
+}
 
-    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
-        return 64;
-    }
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa;
+}
 
-    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
-        return 64;
-    }
-};
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_K2;
+}
 
-template <>
-struct fattn_mma_f16_config<256, 256> {
-    static constexpr int  nbatch_fa      = 32;
-    static constexpr int  nwarps_max     = 4;
-    static constexpr bool Q_in_reg       = true;
-    static constexpr int  nstages_target = 2;
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2;
+}
 
-    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
-        return 128;
-    }
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2;
+}
 
-    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
-        return 128;
-    }
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2;
+}
 
-    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
-        return 128;
-    }
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine;
+}
 
-    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
-        return 128;
-    }
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine;
+}
 
-    static int get_nbatch_combine_host(const int cc, const int ncols) {
-        if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
-            return ncols <= 16 ? 128 : 64;
-        }
-        return 64;
-    }
+static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target;
+}
 
-    static constexpr __device__ int get_nbatch_combine_device(int ncols) {
-#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
-        return ncols <= 16 ? 128 : 64;
-#else
-        GGML_UNUSED(ncols);
-        return 128;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
-    }
-};
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target;
+}
 
-template <>
-struct fattn_mma_f16_config<576, 512> {
-    static constexpr int  nbatch_fa      = 32;
-    static constexpr int  nwarps_max     = 8;
-    static constexpr bool Q_in_reg       = false;
-    static constexpr int  nstages_target = 1;
+static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg;
+}
 
-    static int get_nbatch_K2_host(const int cc, const int ncols) {
-        if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
-            return ncols <= 16 ? 96 : 160;
-        }
-        return ncols <= 16 ? 288 : 160;
-    }
+static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) {
+    return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
+}
 
-    static constexpr __device__ int get_nbatch_K2_device(int ncols) {
-#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
-        return ncols <= 16 ? 96 : 160;
-#else
-        return ncols <= 16 ? 288 : 160;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
-    }
+// ------------------------------------------------------------------------------------------------------------------
 
-    static int get_nbatch_V2_host(const int cc, const int ncols) {
-        if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
-            return ncols <= 16 ? 64 : 128;
-        }
-        return ncols <= 16 ? 256 : 128;
-    }
+static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
+    return cp_async_available(cc) && ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2, cc) : 0;
+}
 
-    static constexpr __device__ int get_nbatch_V2_device(int ncols) {
-#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
-        return ncols <= 16 ? 64 : 128;
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2) {
+#ifdef CP_ASYNC_AVAILABLE
+    return ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2) : 0;
 #else
-        return ncols <= 16 ? 256 : 128;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
-    }
-
-    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
-        return 128;
-    }
-
-    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
-        return 128;
-    }
-};
+    GGML_UNUSED_VARS(DKQ, DV, ncols1, ncols2);
+    return 0;
+#endif // CP_ASYNC_AVAILABLE
+}
 
 // ------------------------------------------------------------------------------------------------------------------
 
-template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
+template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
 static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
-        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
-
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
     // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
     // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
-
-    if (use_cp_async) {
+    if constexpr (use_cp_async) {
+        static_assert(!oob_check, "OOB check not compatible with cp_async");
         constexpr int preload = 64;
         constexpr int h2_per_chunk = 16/sizeof(half2);
         const int chunks_per_row = D2 / h2_per_chunk;
@@ -315,9 +242,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
                 }
             }
         };
-        ggml_cuda_unroll<5>{}(load);
+        // 1: max 32*16=512 bytes, 256 half
+        // 2: max 16*16=256 bytes, 128 half
+        // 3: max  8*16=128 bytes,  64 half
+        // 4: max  4*16= 64 bytes,  32 half
+        // 5: max  2*16= 32 bytes,  16 half
+        // 6: max  1*16= 16 bytes,   8 half
+        ggml_cuda_unroll<6>{}(load);
     } else {
-        static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
+        // TODO use ggml_cuda_memcpy_1
         auto load = [&] __device__ (const int n) {
             const int stride_k = WARP_SIZE >> n;
             const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
@@ -340,20 +273,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                     const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
 
-                    tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
+                    tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
                 }
             }
         };
-        ggml_cuda_unroll<3>{}(load);
+        // 1: max 32* 4=128 bytes,  64 half
+        // 2: max 16* 4= 64 bytes,  32 half
+        // 3: max  8* 4= 32 bytes,  16 half
+        // 4: max  4* 4= 16 bytes,   8 half
+        ggml_cuda_unroll<4>{}(load);
     }
 }
 
-template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async>
+template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
 static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
-        const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
-    static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter");
-
-    if (use_cp_async) {
+        const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
+        const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
+    if constexpr (use_cp_async) {
+        static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
+        static_assert(!oob_check, "OOB check incompatible with cp_async");
         constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
         constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
         constexpr int stride_j = nwarps * cols_per_warp;
@@ -361,50 +299,85 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
         const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
 
 #pragma unroll
-        for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
-            const int j = j0 + threadIdx.y*cols_per_warp +
-                (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp));
+        for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
+            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);
 
-            if (j0 + stride_j > ncols1 && j >= ncols1) {
+            if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
                 break;
             }
 
-            const int i = 4 * (threadIdx.x % (nbatch_fa/8));
+            const int i = 8 * (threadIdx.x % (nbatch_fa/8));
 
-            cp_async_cg_16<preload>(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
+            cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);
         }
-        return;
-    }
+    } else if constexpr (oob_check) {
+#pragma unroll
+        for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
+            const int j_sram = j1 + threadIdx.y;
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);
+
+            if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
+                break;
+            }
 
-    constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
-    constexpr int stride_j = nwarps * cols_per_warp;
 #pragma unroll
-    for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
-        const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp));
+            for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
 
-        if (j0 + stride_j > ncols1 && j >= ncols1) {
-            break;
+                tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
+            }
         }
+    } else if constexpr (nbatch_fa < 2*WARP_SIZE) {
+        constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
+        constexpr int stride_j = nwarps * cols_per_warp;
+#pragma unroll
+        for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
+            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);
 
-        const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp);
+            if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
+                break;
+            }
+
+            const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
 
-        tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i];
+            ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
+        }
+    } else {
+#pragma unroll
+        for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
+            const int j_sram = j1 + threadIdx.y;
+            const int j_vram = fastmodulo(j0 + j_sram, ne01);
+
+            if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
+                break;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
+                const int i = i0 + 2*threadIdx.x;
+
+                ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
+            }
+        }
     }
 }
 
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles,
-    bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
+    bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
+    typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
 static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float2 * const __restrict__ Q_f2,
         const half2  * const __restrict__ K_h2,
         const half2  * const __restrict__ V_h2,
-        const half2  * const __restrict__ mask_h2,
+        const half   * const __restrict__ mask_h,
         float2       * const __restrict__ dstk,
         float2       * const __restrict__ dstk_fixup,
         const float scale,
         const float slope,
         const float logit_softcap,
-        const int ne01,
+        const uint3 ne01,
         const int ne02,
         const int stride_K,
         const int stride_V,
@@ -412,27 +385,24 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         half2        * const __restrict__ tile_Q,
         half2        * const __restrict__ tile_K,
         half2        * const __restrict__ tile_V,
-        half2        * const __restrict__ tile_mask,
-        const tile_B * const __restrict__ Q_B,
-        tile_C_VKQ   * const __restrict__ VKQ_C,
+        half         * const __restrict__ tile_mask,
+        T_B_KQ       * const __restrict__ Q_B,
+        T_C_VKQ      * const __restrict__ VKQ_C,
         float        * const __restrict__ KQ_max,
         float        * const __restrict__ KQ_rowsum,
-        const int kb0) {
-#ifdef TURING_MMA_AVAILABLE
-    typedef fattn_mma_f16_config<DKQ, DV> c;
-
-#ifdef CP_ASYNC_AVAILABLE
-    constexpr int nstages = c::nstages_target;
-#else
-    constexpr int nstages = 0;
-#endif // CP_ASYNC_AVAILABLE
-
-    constexpr int cols_per_warp   = ntiles * tile_B::I;
-    constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
-    constexpr int np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
-    constexpr int ncols           = ncols1 * ncols2;
-    constexpr int nbatch_K2       = c::get_nbatch_K2_device(ncols);
-    constexpr int nbatch_V2       = c::get_nbatch_V2_device(ncols);
+        const int jt,
+        const int kb0,
+        const int k_VKQ_sup) {
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+    constexpr int  ncols           = ncols1 * ncols2;
+    constexpr int  cols_per_warp   = T_B_KQ::I;
+    constexpr int  cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+    constexpr int  np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+    constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
+    constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
+    constexpr int  nbatch_V2       = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
+    constexpr bool Q_in_reg        = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
+    constexpr int  nstages         = ggml_cuda_fattn_mma_get_nstages  (DKQ, DV, ncols1, ncols2);
 
     constexpr int stride_tile_Q = DKQ/2     + 4;
     constexpr int stride_tile_K = nbatch_K2 + 4;
@@ -440,26 +410,27 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
     constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
 
-    const int k_VKQ_0 = kb0 * c::nbatch_fa;
-    tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
-
-    // Use wide variants of tiles if ntiles >= 2.
-    tile_B_16     * Q_B_16   = (tile_B_16     *) Q_B;
-    tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
-    tile_C_KQ_16  * KQ_C_16  = (tile_C_KQ_16  *) KQ_C;
+    const int k_VKQ_0 = kb0 * nbatch_fa;
+#if defined(TURING_MMA_AVAILABLE)
+    T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
+#else // Volta
+    T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
+#endif // defined(TURING_MMA_AVAILABLE)
 
     if constexpr (nstages > 1) {
+        static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
         static_assert(!mla, "multi-stage loading not implemented for MLA");
         static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
         constexpr bool use_cp_async = true;
         cp_async_wait_all();
         __syncthreads();
-        flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-            (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
+        flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
+            (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V, k_VKQ_sup);
     } else {
         constexpr bool use_cp_async = nstages == 1;
-        if (ncols2 > 1 || mask_h2) {
-            flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
+        if (ncols2 > 1 || mask_h) {
+            flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
+                (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
         }
     }
 
@@ -468,10 +439,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
         const int k0_diff = k0_stop - k0_start;
 
-        if (nstages <= 1) {
+        if constexpr (nstages <= 1) {
             constexpr bool use_cp_async = nstages == 1;
-            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-                (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
+            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
+                (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup);
             if (use_cp_async) {
                 cp_async_wait_all();
             }
@@ -479,55 +450,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         }
 
         // Calculate tile of KQ:
-        if constexpr (c::Q_in_reg) {
+        if constexpr (Q_in_reg) {
 #pragma unroll
-            for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
-                const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
+            for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
+                const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;
 #pragma unroll
-                for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
-                    tile_A K_A;
+                for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
+                    T_A_KQ K_A;
                     load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
-                    if (ntiles == 1) {
-                        mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
+                    if constexpr (cols_per_warp == 8) {
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
                     } else {
-#pragma unroll
-                        for (int t = 0; t < ntiles/2; ++t) {
-                            // Wide version of KQ_C is column-major => swap A and B.
-                            mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
-                        }
+                        // Wide version of KQ_C is column-major => swap A and B.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
                     }
                 }
             }
         } else {
-            static_assert(ntiles == 2, "ntiles != 2 not implemented");
+            static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
 #pragma unroll
-            for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
-                load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
+            for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
+                load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
 
 #pragma unroll
-                for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
-                    const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
+                for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
+                    const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;
 
-                    tile_A K_A;
+                    T_A_KQ K_A;
                     load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
 
                     // Wide version of KQ_C is column-major => swap A and B.
-                    mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A);
+                    mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
                 }
             }
         }
 
-        if (nstages <= 1) {
+        if constexpr (nstages <= 1) {
             __syncthreads(); // Only needed if tile_K == tile_V.
         }
     }
 
     if (use_logit_softcap) {
-        static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
+        constexpr int stride = cols_per_warp == 8 ? np*T_C_KQ::I : np*T_C_KQ::J;
+        static_assert(nbatch_fa % stride == 0, "bad loop size");
 #pragma unroll
-        for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) {
+        for (int i = 0; i < nbatch_fa/stride; ++i) {
 #pragma unroll
-            for (int l = 0; l < tile_C_KQ::ne; ++l) {
+            for (int l = 0; l < T_C_KQ::ne; ++l) {
                 KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
             }
         }
@@ -540,34 +509,35 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     }
     float KQ_rowsum_add[cols_per_thread] = {0.0f};
 
-    if (ntiles == 1) {
-        if (ncols2 > 1 || mask_h2) {
+    if constexpr (cols_per_warp == 8) {
+        if (ncols2 > 1 || mask_h) {
 #pragma unroll
-            for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
-                const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
+            for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::I) {
+                const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::I;
 #pragma unroll
-                for (int l = 0; l < tile_C_KQ::ne; ++l) {
-                    const int i = i0 + tile_C_KQ::get_i(l);
-                    const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
+                for (int l = 0; l < T_C_KQ::ne; ++l) {
+                    const int i = i0 + T_C_KQ::get_i(l);
+                    const int j = ((threadIdx.y / np)*T_C_KQ::J + T_C_KQ::get_j(l)) / ncols2;
 
-                    KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
-                        __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]);
+                    KQ_C[i00/(np*T_C_KQ::I)].x[l] += slope * __half2float(tile_mask[j*(nbatch_fa + 8) + i]);
                 }
             }
         }
 
         // Calculate softmax for each KQ column using the current max. value.
         // The divisor is stored in KQ_rowsum and will be applied at the end.
-        static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
+        static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
 #pragma unroll
-        for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
+        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
 #pragma unroll
-            for (int l = 0; l < tile_C_KQ::ne; ++l) {
-                KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
+            for (int l = 0; l < T_C_KQ::ne; ++l) {
+                if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) {
+                    KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l]);
+                }
             }
         }
 
-        // Values per KQ column are spread across 8 threads, does not need full warp reduce:
+        // Values per KQ column are spread across 8 threads:
 #pragma unroll
         for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
@@ -576,73 +546,78 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             }
         }
 
-        static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
+        static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
 #pragma unroll
-        for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
+        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
 #pragma unroll
-            for (int l = 0; l < tile_C_KQ::ne; ++l) {
-                KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
-
-                KQ_rowsum_add[l % 2] += KQ_C[k].x[l];
+            for (int l = 0; l < T_C_KQ::ne; ++l) {
+                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
+                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]);
+                    KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
+                } else {
+                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
+                }
             }
         }
-    } else { // ntiles > 1
-        if (ncols2 > 1 || mask_h2) {
+    } else { // not Turing mma or T_B_KQ::I > 8
+        if (ncols2 > 1 || mask_h) {
 #pragma unroll
-            for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) {
-                const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
+            for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) {
+                const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J;
 #pragma unroll
-                for (int t = 0; t < ntiles/2; ++t) {
-#pragma unroll
-                    for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) {
-                        const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
-                        const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
+                for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) {
+                    const int i = (i0 + T_C_KQ::get_j(l0)) / 2;
+                    const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l0)) / ncols2;
 
-                        const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]);
-                        const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
-                        KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
-                        KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
-                    }
+                    const float2 tmp = __half22float2(((const half2 *)tile_mask)[j*(nbatch_fa/2 + 4) + i]);
+                    KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x;
+                    KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y;
                 }
             }
         }
 
         // Calculate softmax for each KQ column using the current max. value.
         // The divisor is stored in KQ_rowsum and will be applied at the end.
-        static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
-#pragma unroll
-        for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
+        static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
 #pragma unroll
-            for (int t = 0; t < ntiles/2; ++t) {
+        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
 #pragma unroll
-                for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
-                    const int KQ_index = 2*t + (l/2) % 2;
-                    KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
+            for (int l = 0; l < T_C_KQ::ne; ++l) {
+                if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) {
+                    // Turing + Volta:
+                    KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l]);
                 }
             }
         }
 
-        // Values per KQ column are spread across 4 threads, does not need full warp reduce:
 #pragma unroll
         for (int col = 0; col < cols_per_thread; ++col) {
+#if defined(TURING_MMA_AVAILABLE)
+            // Values per KQ column are spread across 4 threads:
+            constexpr int offset_first = 2;
+            constexpr int offset_last  = 1;
+#else
+            // Values per KQ column are spread across 2 threads:
+            constexpr int offset_first = 2;
+            constexpr int offset_last  = 2;
+#endif // defined(TURING_MMA_AVAILABLE)
 #pragma unroll
-            for (int offset = 2; offset >= 1; offset >>= 1) {
+            for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
                 KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
             }
         }
 
-        static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size");
+        static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
 #pragma unroll
-        for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
+        for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
 #pragma unroll
-            for (int t = 0; t < ntiles/2; ++t) {
-#pragma unroll
-                for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
-                    const int KQ_index = 2*t + (l/2) % 2;
-
-                    KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]);
-
-                    KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l];
+            for (int l = 0; l < T_C_KQ::ne; ++l) {
+                // Turing + Volta:
+                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
+                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]);
+                    KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
+                } else {
+                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
                 }
             }
         }
@@ -662,12 +637,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
         }
 
-        if (ntiles == 1) {
+#if defined(TURING_MMA_AVAILABLE)
+        if constexpr (cols_per_warp == 8) {
             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
 #pragma unroll
-            for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+            for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
 #pragma unroll
-                for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                for (int l = 0; l < T_C_VKQ::ne; ++l) {
                     VKQ_C[i].x[l] *= KQ_max_scale_h2;
                 }
             }
@@ -676,46 +652,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             for (int col = 0; col < cols_per_thread; ++col) {
                 const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
 #pragma unroll
-                for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+                for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
 #pragma unroll
-                    for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
-                        VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+                    for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
+                        VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
                     }
                 }
             }
         }
+#else // Volta
+        const half2 KQ_max_scale_h2 = make_half2(
+            KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
+#pragma unroll
+        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+            for (int l = 0; l < T_C_VKQ::ne; ++l) {
+                VKQ_C[i].x[l] *= KQ_max_scale_h2;
+            }
+        }
+#endif // defined(TURING_MMA_AVAILABLE)
     }
 
     // Convert KQ C tiles into B tiles for VKQ calculation:
-    tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles];
-    tile_B_16 * B_16 = (tile_B_16 *) B;
-    static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size");
-    if (ntiles == 1) {
+    T_B_VKQ B[nbatch_fa/(np*2*T_B_VKQ::J)];
+    static_assert(nbatch_fa % (np*2*T_B_VKQ::J) == 0, "bad loop size");
+    if constexpr (cols_per_warp == 8) {
 #pragma unroll
-        for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) {
+        for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
             B[k] = get_transposed(get_half2(KQ_C[k]));
         }
     } else {
-        for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) {
-#pragma unroll
-            for (int t = 0; t < ntiles/2; ++t) {
-                B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
-            }
+        for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
+            B[k] = get_half2(KQ_C[k]);
         }
     }
 
-    if (nstages > 1) {
+    if constexpr (nstages > 1) {
         // Preload K tile for next iteration:
         constexpr bool use_cp_async = true;
         cp_async_wait_all();
         __syncthreads();
         if (!last_iter) {
-            if (ncols2 > 1 || mask_h2) {
-                flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
-                    (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
+            if (ncols2 > 1 || mask_h) {
+                flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
+                    (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
             }
-            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-                (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
+            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
+                (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
         }
     }
 
@@ -724,72 +707,119 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     // Therefore, iterate over V in reverse and re-use the data if possible.
     static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
     constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
+
+    // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
 #pragma unroll
     for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
         const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
         const int i0_diff  = i0_stop - i0_start;
 
-        if (nstages <= 1 && i0_start < reusable_cutoff) {
-            constexpr bool use_cp_async = nstages == 1;
-            flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-                (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
-            if (use_cp_async) {
-                cp_async_wait_all();
+        if constexpr (nstages <= 1) {
+            if (i0_start < reusable_cutoff) {
+                constexpr bool use_cp_async = nstages == 1;
+                flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
+                    (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
+                if (use_cp_async) {
+                    cp_async_wait_all();
+                }
+                __syncthreads();
             }
-            __syncthreads();
         }
         const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
 
-        // Calculate VKQ tile:
+#if defined(TURING_MMA_AVAILABLE)
+        constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
 #pragma unroll
-        for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) {
-            static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size");
+        for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
+            static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size");
 #pragma unroll
-            for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) {
-                const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
+            for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) {
+                const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
 
-                tile_A A;
+                T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
                 load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
-                if (ntiles == 1) {
-                    mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
+                if constexpr (T_B_KQ::I == 8) {
+                    mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
                 } else {
-#pragma unroll
-                    for (int t = 0; t < ntiles/2; ++t) {
-                        // Wide version of VKQ_C is column-major => swap A and B.
-                        mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
-                    }
+                    // Wide version of VKQ_C is column-major => swap A and B.
+                    mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
                 }
             }
         }
+#else // Volta
+        constexpr int i0_stride = 2*T_C_VKQ::J;
+#pragma unroll
+        for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
+            static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, "bad loop size");
+            static_assert(2*T_B_VKQ::J == T_A_VKQ::I, "bad tile sizes");
+#pragma unroll
+            for (int k00 = 0; k00 < nbatch_fa; k00 += np*T_A_VKQ::I) {
+                const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::I;
+
+                T_A_VKQ A; // Transposed in both SRAM and registers, load normally.
+                load_ldmatrix(A, tile_V_i + k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+                mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
+            }
+        }
+#endif // defined(TURING_MMA_AVAILABLE)
 
-        if (nstages <= 1) {
+        if constexpr (nstages <= 1) {
             __syncthreads(); // Only needed if tile_K == tile_V.
         }
     }
 #else
-    GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup,
+    GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup,
         scale, slope, logit_softcap, ne01, ne02,
         stride_K, stride_V, stride_mask,
         tile_Q, tile_K, tile_V, tile_mask,
         Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     NO_DEVICE_CODE;
-#endif // TURING_MMA_AVAILABLE
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
+#if defined(TURING_MMA_AVAILABLE)
+template<int ncols> struct mma_tile_sizes {
+    using T_A_KQ  = tile<16,  8, half2>; // row-major
+    using T_B_KQ  = tile<16,  8, half2>; // column-major
+    using T_C_KQ  = tile<16, 16, float>; // column-major
+    using T_A_VKQ = tile<16,  8, half2>; // row-major
+    using T_B_VKQ = tile<16,  8, half2>; // column-major
+    using T_C_VKQ = tile<16,  8, half2>; // column-major
+};
+template<> struct mma_tile_sizes<8> {
+    using T_A_KQ  = tile<16,  8, half2>; // row-major
+    using T_B_KQ  = tile< 8,  8, half2>; // column-major
+    using T_C_KQ  = tile<16,  8, float>; // row-major
+    using T_A_VKQ = tile<16,  8, half2>; // row-major
+    using T_B_VKQ = tile< 8,  8, half2>; // column-major
+    using T_C_VKQ = tile<16,  4, half2>; // row-major
+};
+#else // Volta
+template<int ncols> struct mma_tile_sizes {
+    using T_A_KQ  = tile< 8,  4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
+    using T_B_KQ  = tile<32,  4, half2, DATA_LAYOUT_I_MAJOR>;          // column-major
+    using T_C_KQ  = tile<32,  8, float, DATA_LAYOUT_I_MAJOR>;          // column-major
+    using T_A_VKQ = tile< 8,  4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major
+    using T_B_VKQ = tile<32,  4, half2, DATA_LAYOUT_I_MAJOR>;          // column-major
+    using T_C_VKQ = tile<32,  4, half2, DATA_LAYOUT_I_MAJOR>;          // column-major
+};
+#endif // defined(TURING_MMA_AVAILABLE)
+
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
 static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float2 * const __restrict__ Q_f2,
         const half2  * const __restrict__ K_h2,
         const half2  * const __restrict__ V_h2,
-        const half2  * const __restrict__ mask_h2,
+        const half   * const __restrict__ mask_h,
         const float  * const __restrict__ sinks_f,
         float2       * const __restrict__ dstk,
         float2       * const __restrict__ dstk_fixup,
         const float scale,
         const float slope,
         const float logit_softcap,
-        const int ne01,
+        const uint3 ne01,
         const int ne02,
+        const int ne11,
         const int stride_Q1,
         const int stride_Q2,
         const int stride_K,
@@ -798,23 +828,31 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jt,
         const int kb0_start,
         const int kb0_stop) {
-#ifdef TURING_MMA_AVAILABLE
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
-    typedef fattn_mma_f16_config<DKQ, DV> c;
-
-#ifdef CP_ASYNC_AVAILABLE
-    constexpr int nstages = c::nstages_target;
-#else
-    constexpr int nstages = 0;
-#endif // CP_ASYNC_AVAILABLE
-
-    constexpr int ncols           = ncols1 * ncols2;
-    constexpr int cols_per_warp   = ntiles * tile_B::I;
-    constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
-    constexpr int np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
-    constexpr int nbatch_K2       = c::get_nbatch_K2_device(ncols);
-    constexpr int nbatch_V2       = c::get_nbatch_V2_device(ncols);
+    constexpr int ncols = ncols1 * ncols2;
+    using     T_A_KQ    = typename mma_tile_sizes<ncols>::T_A_KQ;
+    using     T_B_KQ    = typename mma_tile_sizes<ncols>::T_B_KQ;
+    using     T_C_KQ    = typename mma_tile_sizes<ncols>::T_C_KQ;
+    using     T_A_VKQ   = typename mma_tile_sizes<ncols>::T_A_VKQ;
+    using     T_B_VKQ   = typename mma_tile_sizes<ncols>::T_B_VKQ;
+    using     T_C_VKQ   = typename mma_tile_sizes<ncols>::T_C_VKQ;
+
+    constexpr int  cols_per_warp   = T_B_KQ::I;
+    constexpr int  cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+    constexpr int  np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+    constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa     (DKQ, DV, ncols);
+    constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2     (DKQ, DV, ncols);
+    constexpr int  nbatch_V2       = ggml_cuda_fattn_mma_get_nbatch_V2     (DKQ, DV, ncols);
+    constexpr int  nbatch_combine  = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols);
+    constexpr bool Q_in_reg        = ggml_cuda_fattn_mma_get_Q_in_reg      (DKQ, DV, ncols);
+    constexpr int  nstages         = ggml_cuda_fattn_mma_get_nstages       (DKQ, DV, ncols1, ncols2);
+
+    if (cols_per_warp > ncols) {
+        NO_DEVICE_CODE;
+        return;
+    }
 
     static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
 
@@ -826,15 +864,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
 
     extern __shared__ half2 tile_Q[];
-    half2 * tile_K    = c::Q_in_reg ? tile_Q                                : tile_Q + ncols        * stride_tile_Q;
-    half2 * tile_V    = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K;
-    half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max;
-
-    tile_B       Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles];
-    tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I  * ntiles];
+    half2 * tile_K    = Q_in_reg              ? tile_Q                             : tile_Q + ncols     * stride_tile_Q;
+    half2 * tile_V    =           nstages > 1 ? tile_K + nbatch_fa * stride_tile_K : tile_K;
+    half  * tile_mask = (half *) (nstages > 1 ? tile_V + nbatch_fa * stride_tile_V : tile_V + nbatch_fa * stride_tile_KV_max);
 
-    tile_B_16     * Q_B_16   = (tile_B_16     *) Q_B;
-    tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
+    T_B_KQ    Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
+#if defined(TURING_MMA_AVAILABLE)
+    T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
+#else // Volta
+    T_C_VKQ VKQ_C[                                     DV/(2*T_C_VKQ::J)];
+#endif // defined(TURING_MMA_AVAILABLE)
 
     float KQ_rowsum[cols_per_thread] = {0.0f};
     float KQ_max[cols_per_thread];
@@ -868,7 +907,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             const int j = jc / ncols2;
             const int c = jc % ncols2;
 
-            if (jt*ncols1 + j < ne01) {
+            if (jt*ncols1 + j < int(ne01.z)) {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                     const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -889,63 +928,96 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
     __syncthreads();
 
-    if (c::Q_in_reg) {
+    if (Q_in_reg) {
         const int j0 = (threadIdx.y / np) * cols_per_warp;
 
 #pragma unroll
-        for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) {
-            if (ntiles == 1) {
-                load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
-            } else {
-#pragma unroll
-                for (int t = 0; t < ntiles/2; ++t) {
-                    load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
-                        tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q);
-                }
-            }
+        for (int k0 = 0; k0 < DKQ/2; k0 += T_B_KQ::J) {
+            load_ldmatrix(Q_B[k0/T_B_KQ::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
         }
     }
 
     __syncthreads();
 
+    int kb0 = kb0_start;
+
     // Preload mask and K data for first iteration when using cp_async with multiple stages:
     if constexpr (nstages > 1) {
         static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
         constexpr bool use_cp_async = true;
-        if (ncols2 > 1 || mask_h2) {
-            flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
-                (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
+        constexpr bool oob_check    = false;
+        constexpr int  k_VKQ_sup    = nbatch_fa;
+        if (ncols2 > 1 || mask_h) {
+            flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
+                (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
         }
-        flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-            (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
+        flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
+            (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
     }
 
-    // Iterate over ne11 == previous tokens:
-    int kb0 = kb0_start;
     for (; kb0 < kb0_stop-1; ++kb0) {
         constexpr bool last_iter = false;
-        flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
-            (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
-    }
-    { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
+        constexpr bool oob_check = false;
+        constexpr int  k_VKQ_sup = nbatch_fa;
+        flash_attn_ext_f16_iter
+            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+             T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+            (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+    }
+    // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
+    if constexpr (ncols2 == 1) {
+        if (ne11 % nbatch_fa == 0) {
+            constexpr bool last_iter = true;
+            constexpr bool oob_check = false;
+            constexpr int  k_VKQ_sup = nbatch_fa;
+            flash_attn_ext_f16_iter
+                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+                 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+                (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+        } else {
+            constexpr bool last_iter = true;
+            constexpr bool oob_check = true;
+            const     int  k_VKQ_sup = ne11 - kb0*nbatch_fa;
+            flash_attn_ext_f16_iter
+                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+                 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+                (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+        }
+    } else {
         constexpr bool last_iter = true;
-        flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
-            (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+        constexpr bool oob_check = false;
+        constexpr int  k_VKQ_sup = nbatch_fa;
+        flash_attn_ext_f16_iter
+            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+             T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+            (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
     }
 
     // With multi-stage loading there is no __syncthreads at the end of the iter,
     //     there can be a race condition on shared memory access for combining/writing back results.
-    if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) {
+    if constexpr (nstages > 1 && nwarps*cols_per_warp > nbatch_fa) {
         __syncthreads();
     }
 
     // Finally, sum up partial KQ rowsums.
-    // The partial sums are spread across 8/4 threads each, does not need full reduce.
     {
-        constexpr int offset_first = ntiles == 1 ? 16 : 2;
-        constexpr int offset_last  = ntiles == 1 ?  4 : 1;
+#if defined(TURING_MMA_AVAILABLE)
+        // The partial sums are spread across 8/4 threads.
+        constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
+        constexpr int offset_last  = cols_per_warp == 8 ?  4 : 1;
+#else // Volta
+        // The partial sums are spread across 2 threads.
+        constexpr int offset_first = 2;
+        constexpr int offset_last  = 2;
+#endif // defined(TURING_MMA_AVAILABLE)
 #pragma unroll
         for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
@@ -962,8 +1034,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         float KQ_max_scale[cols_per_thread];
 #pragma unroll
         for (int col = 0; col < cols_per_thread; ++col) {
-            static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
-            const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
+            const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
             const float sink = sinks_f[jc % ncols2];
 
             const float KQ_max_new = fmaxf(KQ_max[col], sink);
@@ -977,12 +1048,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
         }
 
-        if (ntiles == 1) {
+#if defined(TURING_MMA_AVAILABLE)
+        if constexpr (cols_per_warp == 8) {
             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
 #pragma unroll
-            for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+            for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
 #pragma unroll
-                for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                for (int l = 0; l < T_C_VKQ::ne; ++l) {
                     VKQ_C[i].x[l] *= KQ_max_scale_h2;
                 }
             }
@@ -991,30 +1063,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             for (int col = 0; col < cols_per_thread; ++col) {
                 const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
 #pragma unroll
-                for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+                for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
 #pragma unroll
-                    for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
-                        VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+                    for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
+                        VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
                     }
                 }
             }
         }
+#else // Volta
+        const int col = (threadIdx.x / 2) % 2;
+        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+            for (int l = 0; l < T_C_VKQ::ne; ++l) {
+                VKQ_C[i].x[l] *= KQ_max_scale_h2;
+            }
+        }
+#endif // defined(TURING_MMA_AVAILABLE)
     }
 
     // Combine VKQ accumulator values if np > 1.
     // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
     // So also write VKQ accumulators to shared memory in column-major format if np == 1.
 
-    constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
-    constexpr int tile_stride    = nbatch_combine + 4;
+    constexpr int tile_stride = nbatch_combine + 4;
     static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
 
-    if constexpr (ntiles == 1) {
-        const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
-        const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
+    if constexpr (cols_per_warp == 8) {
+        const int jc_cwmo = (threadIdx.x % (2*T_C_VKQ::J)) / T_C_VKQ::J; // jc combine write meta offset
+        const int jc_cwm = threadIdx.y*(2*T_C_VKQ::J) + 2*T_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
         const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum
 
-        if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
+        if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*T_C_VKQ::J) {
             // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
             ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
         }
@@ -1023,24 +1105,30 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
         if (np == 1) {
             // No combination is needed, the meta data can be directly written from registers to VRAM.
-            if (needs_fixup && threadIdx.x < tile_B::I) {
+            if (needs_fixup && threadIdx.x < T_B_KQ::I) {
                 float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
                 dstk_fixup_meta[jc_cwm] = KQ_cmr;
             }
-            if (is_fixup && threadIdx.x < tile_B::I) {
+            if (is_fixup && threadIdx.x < T_B_KQ::I) {
                 float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
                 dstk_fixup_meta[jc_cwm] = KQ_cmr;
             }
         }
     } else {
-        static_assert(ntiles == 2 || ntiles == 4, "bad ntiles");
-        const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta
-            + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0)
-            + tile_C_VKQ_16::get_i(threadIdx.x % 4);
-        const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum
-
-        if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
-            // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
+        // jc_cwm = jc combine write meta
+        // KQ_cmr = KQ combine max rowsum
+        // Use the 16 bytes of padding in each Q column to store the meta data: KQ max, KQ rowsum, KQ max scale.
+#if defined(TURING_MMA_AVAILABLE)
+        const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
+        const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
+        const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
+#else // Volta
+        const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
+        const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
+        const bool thread_should_write = T_C_KQ::J == 8 || T_C_KQ::get_j(threadIdx.x & 2) < 8;
+#endif // defined(TURING_MMA_AVAILABLE)
+
+        if (((!needs_fixup && !is_fixup) || np > 1) && thread_should_write) {
             ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
         }
 
@@ -1048,18 +1136,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
         if (np == 1) {
             // No combination is needed, the meta data can be directly written from registers to VRAM.
-            if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
+            if (needs_fixup && thread_should_write) {
                 float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
                 dstk_fixup_meta[jc_cwm] = KQ_cmr;
             }
-            if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
+            if (is_fixup && thread_should_write) {
                 float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
                 dstk_fixup_meta[jc_cwm] = KQ_cmr;
             }
         }
     }
 
-    static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles");
     if (np > 1 && threadIdx.y % np == 0) {
         // Combine the meta data for parallel warps via shared memory.
         // Warps with threadIdx.y % np != 0 must NOT return early.
@@ -1135,32 +1222,29 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
 #pragma unroll
     for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
-        if (ntiles == 1) {
-            const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
+        if constexpr (cols_per_warp == 8) {
+            const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data
 #pragma unroll
-            for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) {
-                const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
+            for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) {
+                const T_B_KQ B = get_transposed(VKQ_C[(k00 + k1)/T_B_KQ::J]); // Conversion of C to B matrix puts it in column-major format.
 
 #pragma unroll
-                for (int l = 0; l < tile_B::ne; ++l) {
-                    const int k = k0 + tile_B::get_j(l);
+                for (int l = 0; l < T_B_KQ::ne; ++l) {
+                    const int k = k1 + T_B_KQ::get_j(l);
 
                     tile_Q[jc_cwd*tile_stride + k] = B.x[l];
                 }
             }
         } else {
+            const int j0 = threadIdx.y*cols_per_warp;
 #pragma unroll
-            for (int t = 0; t < ntiles/2; ++t) {
-                const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
+            for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
 #pragma unroll
-                for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) {
-#pragma unroll
-                    for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
-                        const int j = j0 + tile_C_VKQ_16::get_i(l);
-                        const int k = k0 + tile_C_VKQ_16::get_j(l);
+                for (int l = 0; l < T_C_VKQ::ne; ++l) {
+                    const int j = j0 + T_C_VKQ::get_i(l);
+                    const int k = k1 + T_C_VKQ::get_j(l);
 
-                        tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
-                    }
+                    tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l];
                 }
             }
         }
@@ -1195,7 +1279,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                     const int j_dst = jc_dst / ncols2;
                     const int c_dst = jc_dst % ncols2;
 
-                    if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
+                    if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
                         continue;
                     }
 
@@ -1233,16 +1317,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
     }
 #else
-    GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dstk_fixup,
+    GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
         scale, slope, logit_softcap, ne01, ne02,
         stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
         jt, kb0_start, kb0_stop);
     NO_DEVICE_CODE;
-#endif // TURING_MMA_AVAILABLE
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
-__launch_bounds__(nwarps*WARP_SIZE, 1)
+template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
+__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -1258,14 +1342,14 @@ static __global__ void flash_attn_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
                             const int32_t nb01, const int32_t nb02, const int32_t nb03,
         const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
                             const int32_t nb11, const int32_t nb12, const int64_t nb13,
                             const int32_t nb21, const int32_t nb22, const int64_t nb23,
                             const int32_t ne31, const int32_t ne32, const int32_t ne33,
                             const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
+#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1281,23 +1365,22 @@ static __global__ void flash_attn_ext_f16(
 
     static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
 
-    typedef fattn_mma_f16_config<DKQ, DV> c;
-
-    static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config<DKQ, DV>::nbatch_fa == 0, "bad nbatch_fa");
+    constexpr int ncols     = ncols1 * ncols2;
+    constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
+    constexpr int nthreads  = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
+    constexpr int nwarps    = nthreads / WARP_SIZE;
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
     const int stride_Q1   = nb01 / sizeof(float2);
     const int stride_Q2   = nb02 / sizeof(float2);
     const int stride_K    = nb11 / sizeof(half2);
-    const int stride_mask = nb31 / sizeof(half2);
+    const int stride_mask = nb31 / sizeof(half);
 
     const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
 
-    const int iter_k = ne11 / FATTN_KQ_STRIDE;
-    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
-
-    constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
+    const int iter_k = (ne11   + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j = (ne01.z + (ncols1    - 1)) / ncols1;
 
     // kbc == k block continuous, current index in continuous ijk space.
     int       kbc      = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
@@ -1318,35 +1401,31 @@ static __global__ void flash_attn_ext_f16(
 
         const int head0 = zt * ncols2;
 
-        const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02* head0);
-        const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
-        const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-            (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-        float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
+        const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02* head0);
+        const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+        const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
+            (const half  *) (mask + nb33*(sequence % ne33));
+        float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
 
         const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
         const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
 
-        const int kb0_start_kernel = kb0_start * kb_niter;
-        int       kb0_stop_kernel  = kb0_stop  * kb_niter;
-
         if (KV_max) {
-            kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+            kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
         }
-
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
-            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
+                (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
+                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
         } else {
-            constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
-            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+            constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
+            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
+                (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
+                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
         }
 
         kbc += iter_k;
@@ -1366,29 +1445,26 @@ static __global__ void flash_attn_ext_f16(
 
     const int head0 = zt * ncols2;
 
-    const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02* head0);
-    const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
-    const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-        (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-    float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
+    const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02* head0);
+    const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+    const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
+        (const half *) (mask + nb33*(sequence % ne33));
+    float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
 
     const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
     const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
     const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
 
-    const int kb0_start_kernel = kb0_start * kb_niter;
-    int       kb0_stop_kernel  = kb0_stop  * kb_niter;
-
     if (KV_max) {
-        kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+        kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
     }
 
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
-    flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-        (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-         ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+    flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
+        (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
+         ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
 #else
     GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
         max_bias, m0, m1, n_head_log2, logit_softcap,
@@ -1400,7 +1476,7 @@ static __global__ void flash_attn_ext_f16(
               ne31, ne32, ne33,
               nb31, nb32, nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
+#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
 }
 
 template <int DKQ, int DV, int ncols1, int ncols2>
@@ -1409,36 +1485,30 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     const int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
 
-    typedef fattn_mma_f16_config<DKQ, DV> c;
+    constexpr int ncols = ncols1 * ncols2;
 
-    const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
+    const int  nthreads       = ggml_cuda_fattn_mma_get_nthreads      (DKQ, DV, ncols, cc);
+    const int  nbatch_fa      = ggml_cuda_fattn_mma_get_nbatch_fa     (DKQ, DV, ncols, cc);
+    const int  nbatch_K2      = ggml_cuda_fattn_mma_get_nbatch_K2     (DKQ, DV, ncols, cc);
+    const int  nbatch_V2      = ggml_cuda_fattn_mma_get_nbatch_V2     (DKQ, DV, ncols, cc);
+    const int  nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols, cc);
+    const bool Q_in_reg       = ggml_cuda_fattn_mma_get_Q_in_reg      (DKQ, DV, ncols, cc);
+    const int  nstages        = ggml_cuda_fattn_mma_get_nstages       (DKQ, DV, ncols1, ncols2, cc);
 
-    constexpr int ncols         = ncols1 * ncols2;
-    constexpr int ntiles        = ncols <= 8 ? 1 : 2; // Number of tiles per warp.
-    constexpr int cols_per_warp = ntiles * tile_B::I;
-    constexpr int nwarps_max_x  = ncols / cols_per_warp;
-    constexpr int nwarps_max_y  = c::nbatch_fa / tile_A::I;
-    constexpr int nwarps        = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
+    const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32);
+    const int nwarps        = nthreads / WARP_SIZE;
 
     constexpr bool mla = DKQ == 576;
 
-    const int nbatch_K2      = c::get_nbatch_K2_host     (cc, ncols);
-    const int nbatch_V2      = c::get_nbatch_K2_host     (cc, ncols);
-    const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
-
-    static_assert(DKQ   % tile_B::J     == 0, "bad DKQ");
-    static_assert(DV    % tile_A::J     == 0, "bad DV");
-    static_assert(ncols % cols_per_warp == 0, "bad ncols");
-
-    const size_t nbytes_shared_KV_1stage = c::nbatch_fa         * std::max(nbatch_K2 + 4,  nbatch_V2 + 4) * sizeof(half2);
-    const size_t nbytes_shared_KV_2stage = c::nbatch_fa         *         (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
+    const size_t nbytes_shared_KV_1stage = nbatch_fa            * std::max(nbatch_K2 + 4,  nbatch_V2 + 4) * sizeof(half2);
+    const size_t nbytes_shared_KV_2stage = nbatch_fa            *         (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
     const size_t nbytes_shared_Q         = ncols                * (DKQ/2 + 4)                             * sizeof(half2);
-    const size_t nbytes_shared_mask      = ncols1               * (c::nbatch_fa/2 + 4)                    * sizeof(half2);
+    const size_t nbytes_shared_mask      = ncols1               * (nbatch_fa/2 + 4)                       * sizeof(half2);
     const size_t nbytes_shared_combine   = nwarps*cols_per_warp * (nbatch_combine + 4)                    * sizeof(half2);
 
     const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
 
-    const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ?
+    const size_t nbytes_shared_total = std::max(nbytes_shared_combine, Q_in_reg ?
         std::max(nbytes_shared_Q,  nbytes_shared_KV + nbytes_shared_mask) :
                  nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask);
 
@@ -1448,7 +1518,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     fattn_kernel_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
+        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
 
 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1459,7 +1529,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
     } else {
         constexpr bool use_logit_softcap = true;
-        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
+        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
 
 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1471,7 +1541,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     }
 
     launch_fattn<DV, ncols1, ncols2>
-        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
+        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
 }
 
 
index 3e58d64ff9d115f01c50cd1b6968bd7208276abc..63b235674eb8d19c38caca2c0994fcc8183f26bd 100644 (file)
@@ -501,6 +501,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
         const half2 * const __restrict__ K_h2,
         const half2 * const __restrict__ V_h2,
         const half  * const __restrict__ mask,
+        const uint3 ne01,
         const float logit_softcap,
         const float slope,
         T_KQ      * const KQ,
@@ -512,7 +513,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
         float * const KQ_sum,
         T_acc * const VKQ,
         const int k_VKQ_0,
-        const int k_VKQ_max) {
+        const int k_VKQ_max,
+        const int col_Q_0) {
     constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
     constexpr int cpy_ne = cpy_nb / 4;
 
@@ -556,7 +558,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
     // Apply logit softcap + mask, update KQ_max:
 #pragma unroll
     for (int jc0 = 0; jc0 < cpw; ++jc0) {
-        const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2;
+        const int j = fastmodulo(col_Q_0 + (jc0 + (threadIdx.y / np)*cpw)/ncols2, ne01);
 
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
@@ -736,7 +738,7 @@ static __global__ void flash_attn_tile(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
                             const int32_t nb01, const int32_t nb02, const int32_t nb03,
         const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
                             const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -781,11 +783,11 @@ static __global__ void flash_attn_tile(
     const int sequence = blockIdx.z / (ne02/ncols2);
     const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2)
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f  = (const float *) (Q + nb03*sequence + nb02* head0              + nb01*col_Q_0);
+    const float * Q_f  = (const float *) (Q + nb03*sequence + nb02* head0);
     const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
     const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape
 
-    const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr;
+    const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr;
 
     const int stride_K2   = nb11 / sizeof(half2);
     const int stride_V2   = nb21 / sizeof(half2);
@@ -842,11 +844,9 @@ static __global__ void flash_attn_tile(
         for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
             if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
                 float tmp_f[cpy_ne_D] = {0.0f};
-                if (ncols1 == 1 || col_Q_0 + j < ne01) {
-                    ggml_cuda_memcpy_1<sizeof(tmp_f)>
-                        (tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float))
-                                     + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
-                }
+                ggml_cuda_memcpy_1<sizeof(tmp_f)>
+                    (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))
+                                 + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
 
 #pragma unroll
                 for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -881,23 +881,23 @@ static __global__ void flash_attn_tile(
         while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {
             constexpr bool oob_check = false;
             flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
-                (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
-                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
+                (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
+                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
             k_VKQ_0 += gridDim.y*nbatch_fa;
         }
         if (k_VKQ_0 < k_VKQ_max) {
             constexpr bool oob_check = true;
             flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
-                (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
-                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
+                (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
+                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
         }
     } else {
         // Branch without out-of-bounds checks.
         for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) {
             constexpr bool oob_check = false;
             flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
-                (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
-                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
+                (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
+                stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
         }
     }
 
@@ -1010,13 +1010,13 @@ static __global__ void flash_attn_tile(
         const int j = jc / ncols2;
         const int c = jc % ncols2;
 
-        if (ncols1 > 1 && col_Q_0 + j >= ne01) {
+        if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z)) {
             return;
         }
 
         const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
 
-        const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
+        const int j_dst_unrolled = ((sequence*int(ne01.z) + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
         constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
index 67aa67ecb9426cefc923075689a383c05431d81b..0bae9849a96fb306e4388ef71bdf55ec577686e8 100644 (file)
@@ -33,7 +33,7 @@ static __global__ void flash_attn_ext_vec(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
                             const int32_t nb01, const int32_t nb02, const int32_t nb03,
         const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
                             const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -150,7 +150,7 @@ static __global__ void flash_attn_ext_vec(
             float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));
 
             // Set memory to zero if out of bounds:
-            if (ncols > 1 && ic0 + j >= ne01) {
+            if (ncols > 1 && ic0 + j >= int(ne01.z)) {
 #pragma unroll
                 for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
                     const int i = i0 + threadIdx.x;
@@ -201,7 +201,7 @@ static __global__ void flash_attn_ext_vec(
                 const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
 
                 float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
-                if (ncols == 1 || ic0 + j < ne01) {
+                if (ncols == 1 || ic0 + j < int(ne01.z)) {
                     ggml_cuda_memcpy_1<cpy_nb>(tmp,            &Q_j[i]);
                     ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
                 }
@@ -222,7 +222,7 @@ static __global__ void flash_attn_ext_vec(
 #pragma unroll
             for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
                 const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
-                if (ncols == 1 || ic0 + j < ne01) {
+                if (ncols == 1 || ic0 + j < int(ne01.z)) {
                     ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ],            &Q_j[i]);
                     ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
                 }
@@ -266,7 +266,7 @@ static __global__ void flash_attn_ext_vec(
                     sum = logit_softcap*tanhf(sum);
                 }
 
-                if (mask) {
+                if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) {
                     sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
                 }
 
@@ -412,7 +412,7 @@ static __global__ void flash_attn_ext_vec(
 
 #pragma unroll
     for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
-        if (ncols > 1 && ic0 + j_VKQ >= ne01) {
+        if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) {
             break;
         }
 
@@ -479,7 +479,7 @@ static __global__ void flash_attn_ext_vec(
                 if (gridDim.y == 1) {
                     dst_val /= KQ_sum[j_VKQ];
                 }
-                dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
+                dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
             }
         }
 
@@ -489,8 +489,8 @@ static __global__ void flash_attn_ext_vec(
 
     }
 
-    if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) {
-        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
+    if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) {
+        dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
     }
 #else
     GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
index 6c90d6d52b3351b36433749614477ddd1aee3410..0d81f0aae0a75bbd192e8fa1e029ae23a13f81df 100644 (file)
@@ -38,14 +38,14 @@ static __global__ void flash_attn_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
                             const int32_t nb01, const int32_t nb02, const int32_t nb03,
         const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
                             const int32_t nb11, const int32_t nb12, const int64_t nb13,
                             const int32_t nb21, const int32_t nb22, const int64_t nb23,
                             const int32_t ne31, const int32_t ne32, const int32_t ne33,
                             const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
+#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -149,7 +149,7 @@ static __global__ void flash_attn_ext_f16(
             if (i0 + warp_size > D && i >= D) {
                 break;
             }
-            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+            KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f;
         }
     }
 
@@ -218,7 +218,8 @@ static __global__ void flash_attn_ext_f16(
                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?
+                        __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
                     KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
                 }
                 KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
@@ -270,7 +271,7 @@ static __global__ void flash_attn_ext_f16(
                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
                     KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
                 }
                 KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
@@ -431,7 +432,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j_VKQ = j0 + threadIdx.y;
-        if (ic0 + j_VKQ >= ne01) {
+        if (ic0 + j_VKQ >= int(ne01.z)) {
             return;
         }
 
@@ -442,7 +443,7 @@ static __global__ void flash_attn_ext_f16(
             KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
         }
 
-        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+        const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
 
 #pragma unroll
         for (int i0 = 0; i0 < D; i0 += warp_size) {
@@ -481,7 +482,7 @@ static __global__ void flash_attn_ext_f16(
               ne31, ne32, ne33,
               nb31, nb32, nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
 }
 
 constexpr int get_max_power_of_2(int x) {
index 7235f1b77aeedcc426a8793a8f456461be28bc3d..cd3bfd4051a40b4d5da00d3ec8f0a83aa8a0da9e 100644 (file)
@@ -2,9 +2,9 @@
 
 #include "common.cuh"
 
-#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+#if defined(GGML_USE_MUSA)
 #define GGML_USE_WMMA_FATTN
-#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+#endif // defined(GGML_USE_MUSA)
 
 #if defined(GGML_HIP_ROCWMMA_FATTN)
 #if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
index 82405991cea6e5f02aa2958d2510f1be161278d6..dec01ff8ad2a0ccdebace0aa39650c5dde56105a 100644 (file)
@@ -12,13 +12,13 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
     const ggml_tensor * Q = dst->src[0];
 
     if constexpr (ncols2 <= 8) {
-        if (Q->ne[1] <= 8/ncols2) {
+        if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) {
             ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
             return;
         }
     }
 
-    if (Q->ne[1] <= 16/ncols2) {
+    if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) {
         ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
         return;
     }
@@ -41,7 +41,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
     float max_bias = 0.0f;
     memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
 
-    const bool use_gqa_opt = mask && max_bias == 0.0f;
+    const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
 
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
@@ -275,8 +275,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
 
-    // If Turing tensor cores available, use them:
-    if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) {
+    // If Turing tensor cores are available, use them:
+    if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
         if (can_use_vector_kernel) {
             if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
@@ -297,7 +297,21 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
                 return BEST_FATTN_KERNEL_VEC;
             }
         }
+        return BEST_FATTN_KERNEL_MMA_F16;
+    }
 
+    if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
+        int gqa_ratio_eff = 1;
+        const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
+        while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
+            gqa_ratio_eff *= 2;
+        }
+        if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) {
+            return BEST_FATTN_KERNEL_VEC;
+        }
+        if (Q->ne[1] * gqa_ratio_eff <= 16) {
+            return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices.
+        }
         return BEST_FATTN_KERNEL_MMA_F16;
     }
 
index 0ed42e87d3d504cd6c8956d87b3969c20a18c480..6ea7a809a479079cad74292e318e4465128aa7cf 100644 (file)
@@ -68,10 +68,31 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
 
 namespace ggml_cuda_mma {
 
+    // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
+    //     effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
+    // In those cases the data can be split in different ways across the warp.
+    enum data_layout {
+        // By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
+        // For the A/C matrices this means I major == row major, J major == column major.
+        // For the B matrix this means I major == column major, J major == row major.
+        // MIRRORED == Each data value is held exactly once per thread subgroup.
+        DATA_LAYOUT_I_MAJOR           =  0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
+        DATA_LAYOUT_I_MAJOR_MIRRORED  = 10,
+        DATA_LAYOUT_J_MAJOR_MIRRORED  = 20,
+    };
+    // Implemented mma combinations are:
+    //   - (I_MAJOR, I_MAJOR)          -> I_MAJOR
+    //   - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
+    //   - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
+
+    template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
+    struct tile {};
+
     template <int I_, int J_, typename T>
-    struct tile {
-        static constexpr int I  = I_;
-        static constexpr int J  = J_;
+    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int         I  = I_;
+        static constexpr int         J  = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
 
 #if defined(AMD_MFMA_AVAILABLE)
         static constexpr int ne = I * J / 64;
@@ -131,9 +152,9 @@ namespace ggml_cuda_mma {
         static __device__ __forceinline__ int get_i(const int l) {
             if constexpr (I == 32 && J == 8) {
 #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
-                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
+                return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
 #else
-                return (l & 2) | (threadIdx.x & ~2);
+                return (l & 2) + (threadIdx.x & ~2);
 #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
             } else {
                 NO_DEVICE_CODE;
@@ -143,7 +164,7 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 32 && J == 8) {
-                return (threadIdx.x & 2) | (l & (4 + 1));
+                return (threadIdx.x & 2) + (l & (4 + 1));
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -196,9 +217,9 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 8) {
-                return ((l / 2) * 8) | (threadIdx.x / 4);
+                return ((l / 2) * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 16) {
-                return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
+                return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 32 && J == 8) {
                 return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
@@ -211,11 +232,11 @@ namespace ggml_cuda_mma {
             if constexpr (I == 8 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 8 && J == 8) {
-                return (l * 4) | (threadIdx.x % 4);
+                return (l * 4) + (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 8) {
-                return ((threadIdx.x % 4) * 2) | (l % 2);
+                return ((threadIdx.x % 4) * 2) + (l % 2);
             } else if constexpr (I == 16 && J == 16) {
-                return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
+                return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
             } else if constexpr (I == 32 && J == 8) {
                 return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
             } else {
@@ -227,26 +248,24 @@ namespace ggml_cuda_mma {
     };
 
     template <int I_, int J_>
-    struct tile<I_, J_, half2> {
-        static constexpr int I  = I_;
-        static constexpr int J  = J_;
+    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int         I  = I_;
+        static constexpr int         J  = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
 
 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
-        static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
+        static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};
 
         static constexpr __device__ bool supported() {
-            if (I ==  8 && J ==  8) return true;
-            if (I == 32 && J ==  8) return true;
+            if (I == 32 && J ==  4) return true;
             return false;
         }
 
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 8 && J == 8) {
-                return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
-            } else if constexpr (I == 32 && J == 8) {
+            if constexpr (I == 32 && J == 4) {
 #ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
-                return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
+                return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
 #else
                 return threadIdx.x;
 #endif // GGML_CUDA_MMA_NO_VOLTA_PERM
@@ -257,7 +276,7 @@ namespace ggml_cuda_mma {
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr ((I == 8 || I == 32) && J == 8) {
+            if constexpr (I == 32 && J == 4) {
                 return l;
             } else {
                 NO_DEVICE_CODE;
@@ -307,11 +326,11 @@ namespace ggml_cuda_mma {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return (l * 8) | (threadIdx.x / 4);
+                return (l * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return ((l % 2) * 8) | (threadIdx.x / 4);
+                return ((l % 2) * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 32 && J == 8) {
-                return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
+                return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -320,13 +339,13 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return (l * 4) | (threadIdx.x % 4);
+                return (l * 4) + (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return ((l / 2) * 4) | (threadIdx.x % 4);
+                return ((l / 2) * 4) + (threadIdx.x % 4);
             } else if constexpr (I == 32 && J == 8) {
-                return ((l & 2) * 2) | (threadIdx.x % 4);
+                return ((l & 2) * 2) + (threadIdx.x % 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -336,14 +355,15 @@ namespace ggml_cuda_mma {
     };
 
     template <int I_, int J_>
-    struct tile<I_, J_, nv_bfloat162> {
-        static constexpr int I  = I_;
-        static constexpr int J  = J_;
+    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
+        static constexpr int         I  = I_;
+        static constexpr int         J  = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
+        static constexpr int         ne = I * J / WARP_SIZE;
 
-#if defined(AMD_WMMA_AVAILABLE)
-        static constexpr int ne = I * J / 32;
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 
+#if defined(AMD_WMMA_AVAILABLE)
         static constexpr __device__ bool supported() {
             if (I == 16 && J == 8) return true;
             return false;
@@ -367,9 +387,6 @@ namespace ggml_cuda_mma {
             }
         }
 #else
-        static constexpr int ne = I * J / WARP_SIZE;
-        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
-
         static constexpr __device__ bool supported() {
             if (I ==  8 && J ==  8) return true;
             if (I == 16 && J ==  4) return true;
@@ -381,9 +398,9 @@ namespace ggml_cuda_mma {
             if constexpr (I == 8 && J == 8) {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 4) {
-                return (l * 8) | (threadIdx.x / 4);
+                return (l * 8) + (threadIdx.x / 4);
             } else if constexpr (I == 16 && J == 8) {
-                return ((l % 2) * 8) | (threadIdx.x / 4);
+                return ((l % 2) * 8) + (threadIdx.x / 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -392,11 +409,11 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 8 && J == 8) {
-                return (l * 4) | (threadIdx.x % 4);
+                return (l * 4) + (threadIdx.x % 4);
             } else if constexpr (I == 16 && J == 4) {
                 return threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
-                return ((l / 2) * 4) | (threadIdx.x % 4);
+                return ((l / 2) * 4) + (threadIdx.x % 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -405,6 +422,73 @@ namespace ggml_cuda_mma {
 #endif  // defined(AMD_WMMA_AVAILABLE)
     };
 
+    template <int I_, int J_>
+    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
+        static constexpr int         I  = I_;
+        static constexpr int         J  = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+        static constexpr int         ne = I * J / (WARP_SIZE/4);
+
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I ==  8 && J ==  4) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int /*l*/) {
+            if constexpr (I == 8 && J == 4) {
+                return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
+
+    template <int I_, int J_>
+    struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
+        static constexpr int         I  = I_;
+        static constexpr int         J  = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
+        static constexpr int         ne = I * J / (WARP_SIZE/4);
+
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I ==  8 && J ==  4) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return ((l / 2) * 4) + (threadIdx.x % 4);
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return ((threadIdx.x / 16) * 2) + (l % 2);
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
+
+#if defined(TURING_MMA_AVAILABLE)
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
         tile<I, J/2, half2> ret;
@@ -422,9 +506,26 @@ namespace ggml_cuda_mma {
 
         return ret;
     }
+#else // Volta
+    template <int I, int J>
+    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
+        tile<I, J/2, half2> ret;
+#pragma unroll
+        for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
+            ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
+            ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);
+
+            // On Volta FP16 and FP32 tiles have a different memory layout,
+            //     for the conversion threads with an offset of 2 need to exchange half their values:
+            ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
+                0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
+        }
+        return ret;
+    }
+#endif // defined(TURING_MMA_AVAILABLE)
 
-    template <int I, int J, typename T>
-    static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
+    template <int I, int J, typename T, data_layout dl>
+    static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
 #if defined(AMD_MFMA_AVAILABLE)
         if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
 #pragma unroll
@@ -511,18 +612,6 @@ namespace ggml_cuda_mma {
             : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
             : "l"(xs));
 #else
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
-        GGML_UNUSED_VARS(t, xs0, stride);
-        NO_DEVICE_CODE;
-#else
-        load_generic(t, xs0, stride);
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
-#endif // TURING_MMA_AVAILABLE
-    }
-
-    template <typename T>
-    static __device__ __forceinline__ void load_ldmatrix(
-            tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #if 1
         // TODO: more generic handling
@@ -533,9 +622,31 @@ namespace ggml_cuda_mma {
         load_generic(t, xs0, stride);
 #endif // 1
 #else
-        tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
-        load_ldmatrix(t16[0], xs0 +  0*stride, stride);
-        load_ldmatrix(t16[1], xs0 + 16*stride, stride);
+        load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
+        ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
+    }
+
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
+#pragma unroll
+        for (int l0 = 0; l0 < t.ne; l0 += 2) {
+            ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
+        }
+    }
+
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
+#else
+        GGML_UNUSED_VARS(t, xs0, stride);
+        NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     }
 
@@ -860,14 +971,14 @@ namespace ggml_cuda_mma {
     template <typename T1, typename T2, int J, int K>
     static __device__ __forceinline__ void mma(
             tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
-        tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
-        tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
+        tile      <16, J, T1> * D16 = reinterpret_cast<      tile<16, J, T1> *>(&D);
+        const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
         mma(D16[0], A16[0], B);
         mma(D16[1], A16[1], B);
     }
 
     static __device__ __forceinline__ void mma(
-            tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
+            tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
 #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -880,20 +991,30 @@ namespace ggml_cuda_mma {
             "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
             : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
-        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
-            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
-            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
-            : "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
-        asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
-            "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
-            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
-            : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
 #else
-        tile      <16, 8, float> * D16 = reinterpret_cast<tile      <16, 8, float> *>(&D);
-        const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
-        mma(D16[0], A16[0], B);
-        mma(D16[1], A16[1], B);
-#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+        asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
     }
 
 static __device__ __forceinline__ void mma(
index c2a0a2e42fe2fd56db34f553f3675d39861ae11a..e1c695c5c0f96a18f4a0124f3b546a1ed4bbc82d 100644 (file)
@@ -37,23 +37,19 @@ static __global__ void mul_mat_f(
     typedef tile<16,       8, T>     tile_A;
     typedef tile<tile_B_I, 8, T>     tile_B;
     typedef tile<16,       tile_C_J, float> tile_C;
-
-    constexpr bool a_supported = tile_A::supported();
-    constexpr bool b_supported = tile_B::supported();
-    constexpr bool c_supported = tile_C::supported();
-    constexpr bool supported = a_supported && b_supported && c_supported;
 #else
-    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
-    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
-    constexpr bool supported = I_16_supported || I_32_supported;
-
-    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
-
-    typedef tile<I_preferred, 8, T>     tile_A;
-    typedef tile<8,           8, T>     tile_B;
-    typedef tile<I_preferred, 8, float> tile_C;
+#ifdef VOLTA_MMA_AVAILABLE
+    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+    typedef tile<32, 4, T,     DATA_LAYOUT_I_MAJOR>          tile_A;
+    typedef tile< 8, 4, T,     DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
+    typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR>          tile_C;
+#else
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile<8,  8, T>     tile_B;
+    typedef tile<16, 8, float> tile_C;
+#endif // VOLTA_MMA_AVAILABLE
 #endif // defined(AMD_WMMA_AVAILABLE)
-    if constexpr (!supported) {
+    if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
         NO_DEVICE_CODE;
         return;
     }
@@ -248,6 +244,9 @@ static __global__ void mul_mat_f(
             }
         }
     }
+#ifdef VOLTA_MMA_AVAILABLE
+    }
+#endif //VOLTA_MMA_AVAILABLE
 #else
     GGML_UNUSED_VARS(x, y, ids, dst,
         ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
@@ -278,27 +277,24 @@ static __global__ void mul_mat_f_ids(
     typedef tile<16,       8, T>     tile_A;
     typedef tile<tile_B_I, 8, T>     tile_B;
     typedef tile<16,       tile_C_J, float> tile_C;
-
-    constexpr bool a_supported = tile_A::supported();
-    constexpr bool b_supported = tile_B::supported();
-    constexpr bool c_supported = tile_C::supported();
-    constexpr bool supported = a_supported && b_supported && c_supported;
 #else
-    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
-    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
-    constexpr bool supported = I_16_supported || I_32_supported;
-
-    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
-
-    typedef tile<I_preferred, 8, T>     tile_A;
-    typedef tile<8,           8, T>     tile_B;
-    typedef tile<I_preferred, 8, float> tile_C;
+#ifdef VOLTA_MMA_AVAILABLE
+    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+    typedef tile<32, 4, T,     DATA_LAYOUT_I_MAJOR>          tile_A;
+    typedef tile< 8, 4, T,     DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
+    typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR>          tile_C;
+#else
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile<8,  8, T>     tile_B;
+    typedef tile<16, 8, float> tile_C;
+#endif // VOLTA_MMA_AVAILABLE
 #endif // defined(AMD_WMMA_AVAILABLE)
-    if constexpr (!supported) {
+    if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
         NO_DEVICE_CODE;
         return;
     }
 
+
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
     constexpr int ntA = rows_per_block / tile_A::I;
@@ -517,6 +513,9 @@ static __global__ void mul_mat_f_ids(
             }
         }
     }
+#ifdef VOLTA_MMA_AVAILABLE
+    }
+#endif // VOLTA_MMA_AVAILABLE
 #else
     GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
         ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,