]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: use async data loading for FlashAttention (llama/11894)
authorJohannes Gäßler <redacted>
Mon, 17 Feb 2025 13:03:24 +0000 (14:03 +0100)
committerGeorgi Gerganov <redacted>
Thu, 27 Feb 2025 06:55:36 +0000 (08:55 +0200)
* CUDA: use async data loading for FlashAttention

---------

Co-authored-by: Diego Devesa <redacted>
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/cp-async.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/fattn-common.cuh
ggml/src/ggml-cuda/fattn-mma-f16.cuh
ggml/src/ggml-cuda/mma.cuh
ggml/src/ggml-cuda/mmq.cuh

index fd4dcfa941d4bb7a3aa4dce36a3c0eff8953f788..4a92d35f9f40ca1ab191ee5a336cc632168c4423 100644 (file)
 #define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
 #define CUDART_HMASK  12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
 
-#define GGML_CUDA_CC_PASCAL     600
-#define GGML_CUDA_CC_DP4A       610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#define GGML_CUDA_CC_VOLTA      700
-#define GGML_CUDA_CC_TURING     750
-#define GGML_CUDA_CC_AMPERE     800
-#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
+#define GGML_CUDA_CC_PASCAL       600
+#define GGML_CUDA_CC_DP4A         610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define GGML_CUDA_CC_VOLTA        700
+#define GGML_CUDA_CC_TURING       750
+#define GGML_CUDA_CC_AMPERE       800
+#define GGML_CUDA_CC_ADA_LOVELACE 890
+#define GGML_CUDA_CC_OFFSET_AMD   0x1000000
 
 // GCN/CNDA, wave size is 64
 #define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 0x803)  // Tonga, Fiji, Polaris, minimum for fast fp16
@@ -199,6 +200,10 @@ typedef float2 dfloat2;
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define CP_ASYNC_AVAILABLE
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
 #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 #define FLASH_ATTN_AVAILABLE
 #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
@@ -231,6 +236,10 @@ static bool new_mma_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
 
+static bool cp_async_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
 static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     return __AMDGCN_WAVEFRONT_SIZE;
diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh
new file mode 100644 (file)
index 0000000..51aa41e
--- /dev/null
@@ -0,0 +1,46 @@
+// Simplified API for asynchronous data loading.
+
+#include "common.cuh"
+
+// Copies data from global to shared memory, cg == cache global.
+// Both the src and dst pointers must be aligned to 16 bit.
+// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
+// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
+// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements.
+template <int preload>
+static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
+    static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
+#ifdef CP_ASYNC_AVAILABLE
+#if CUDART_VERSION >= 11040
+    if (preload == 256) {
+        asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else if (preload == 128) {
+        asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else if (preload == 64) {
+        asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else
+#endif // CUDART_VERSION >= 11040
+    {
+        asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    }
+#else
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src);
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+
+// Makes each thread wait until its asynchronous data copies are done.
+// This does NOT provide any additional synchronization.
+// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
+static __device__ __forceinline__ void cp_async_wait_all() {
+#ifdef CP_ASYNC_AVAILABLE
+    asm volatile("cp.async.wait_all;");
+#else
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
index d40ee2da4188799bc8d52c3d9a565ed5b7c0eebf..fefbd319baf76ff1637aa400215aee7efb2e2884 100644 (file)
@@ -716,7 +716,9 @@ void launch_fattn(
 
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+    const int id  = ggml_cuda_get_device();
+    const int cc  = ggml_cuda_info().devices[id].cc;
+    const int nsm = ggml_cuda_info().devices[id].nsm;
 
     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
@@ -768,13 +770,14 @@ void launch_fattn(
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int tiles_nwaves  = (ntiles_total - nsm - 1) / nsm;
-        const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
-        const bool short_context = K->ne[1] < 4096;
+        const int tiles_nwaves  = (ntiles_total + 2*nsm - 1) / (2*nsm);
+        const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
 
         const int nblocks_stream_k = 2*nsm;
 
-        blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
+        const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
+
+        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
@@ -827,7 +830,7 @@ void launch_fattn(
     CUDA_CHECK(cudaGetLastError());
 
     if constexpr (parallel_blocks == 0) {
-        if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
             const dim3 blocks_num_combine = blocks_num;
 
index 05bc91a3b8d6f6a5b95a9d5b59ce1bb4eeabd31b..d777f5413ed97277a0edaab46a3b8a188b8b646d 100644 (file)
 #include "common.cuh"
+#include "cp-async.cuh"
 #include "mma.cuh"
 #include "fattn-common.cuh"
 
-template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
-static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
-        const float2 * const __restrict__ Q_f2,
-        const half2  * const __restrict__ K_h2,
-        const half2  * const __restrict__ V_h2,
-        const half   * const __restrict__ maskh,
-        float2       * const __restrict__ dstk,
-        float2       * const __restrict__ dstk_fixup,
-        const float scale,
-        const float slope,
-        const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int nb31,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3,
-        const int jt,
-        const int kb0_start,
-        const int kb0_stop) {
-#ifdef NEW_MMA_AVAILABLE
-    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+using namespace ggml_cuda_mma;
 
-    typedef mma_A_I16K8<half2> mma_A;
-    typedef mma_B_J8K8<half2>  mma_B;
-    typedef mma_C_I16J8<float> mma_C_KQ;
-    typedef mma_C_I16J8<half2> mma_C_VKQ;
-
-    static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps");
-    constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column.
-
-    static_assert(D         % nwarps == 0, "bad D");
-    static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
+typedef tile<16, 8, half2> tile_A;
+typedef tile< 8, 8, half2> tile_B;
+typedef tile<16, 8, float> tile_C_KQ;
+typedef tile<16, 4, half2> tile_C_VKQ;
 
+template<int D, int nwarps, int KQ_stride>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
     constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
-    extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements.
 
-    const int stride_Q    = nb01 / sizeof(float2);
-    const int stride_KV   = nb11 / sizeof(half2);
-    const int stride_mask = nb31 / sizeof(half);
+    // If cp.async is available, load up to the highest power of 2 in D asynchronously:
+#ifdef CP_ASYNC_AVAILABLE
+    static_assert(D >= 64 && D < 512, "bad D");
+    constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128);
 
-    mma_B Q_B[D/(2*mma_B::K)];
-    mma_C_VKQ VKQ_C[D/mma_C_VKQ::I];
+    const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);
 
-    float2    KQ_rowsum = {0.0f, 0.0f};
-    float2       KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
-    float2 KQ_max_scale = {0.0f, 0.0f};
+    constexpr int preload = 64;
+    constexpr int h2_per_chunk = 16/sizeof(half2);
+    constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
+    constexpr int stride_i = WARP_SIZE / chunks_per_row;
+#pragma unroll
+    for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
+        const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
+        const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
 
-    // Temporarily load Q data into tile_KV, will be loaded into registers afterwards.
-    // The loading is done with decreasing granularity for D for better memory bandwidth.
-    const half2 scale_h2 = make_half2(scale, scale);
+        cp_async_cg_16<preload>(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k);
+    }
+#else
+    constexpr int k0_sync_start = 0;
+#endif // CP_ASYNC_AVAILABLE
+    static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start");
+
+    // If D is not a power of 2, the rest is loaded synchronously.
+    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
+    static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
 #pragma unroll
     for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-        const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
-        const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
-        const int stride_j = WARP_SIZE / stride_k;
+        const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
+        const int k0_stop  =                                         D/2 - (D/2) % (1*stride_k);
+        const int stride_i = WARP_SIZE / stride_k;
 
-        if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
-            break;
+        if (k0_start == k0_stop || k0_stop <= k0_sync_start) {
+            continue;
         }
 
 #pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
-            const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+        for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
+            const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
 
-            if (jt*ncols + j < ne01) {
-#pragma unroll
-                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
-
-                    const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
-                    tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
-                }
-            } else {
 #pragma unroll
-                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+            for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
 
-                    tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f);
-                }
+                tile_KV[i*D2_padded + k] = KV[i*stride_KV + k];
             }
         }
     }
+}
 
-    __syncthreads();
-
-    {
-        const int j0 = (threadIdx.y / np) * mma_B::J;
+template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
+static __device__ __forceinline__ void flash_attn_ext_f16_iter(
+        const float2 * const __restrict__ Q_f2,
+        const half2  * const __restrict__ K_h2,
+        const half2  * const __restrict__ V_h2,
+        const half   * const __restrict__ maskh,
+        float2       * const __restrict__ dstk,
+        float2       * const __restrict__ dstk_fixup,
+        const float scale,
+        const float slope,
+        const float logit_softcap,
+        const int ne01,
+        const int ne02,
+        const int stride_Q,
+        const int stride_KV,
+        const int stride_mask,
+        const int jt,
+        half2        * const __restrict__ tile_K,
+        half2        * const __restrict__ tile_V,
+        const tile_B * const __restrict__ Q_B,
+        tile_C_VKQ   * const __restrict__ VKQ_C,
+        float2 & KQ_max,
+        float2 & KQ_rowsum,
+        const int kb0) {
+#ifdef NEW_MMA_AVAILABLE
+    constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
+    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
 
-#pragma unroll
-        for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
-            Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded);
-        }
-    }
+    const int k_VKQ_0 = kb0*KQ_stride;
+    tile_C_KQ KQ_C[KQ_stride/(np*tile_C_KQ::I)];
 
+#ifdef CP_ASYNC_AVAILABLE
+    cp_async_wait_all();
     __syncthreads();
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
+#else
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
+    __syncthreads();
+#endif // CP_ASYNC_AVAILABLE
 
-    // Iterate over ne11 == previous tokens:
-    for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
-        const int k_VKQ_0 = kb0*KQ_stride;
-        mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)];
-
-        // Load K data into tile with decreasing granularity for D for better memory bandwidth:
-        static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
-#pragma unroll
-        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-            const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
-            const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
-            const int stride_i = WARP_SIZE / stride_k;
-
-#pragma unroll
-            for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) {
-                const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
-
-#pragma unroll
-                for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) {
-                    const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
-
-                    tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ];
-                }
-            }
-        }
-
-        __syncthreads();
-
-        // Calculate tile of KQ:
+    // Calculate tile of KQ:
 #pragma unroll
-        for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) {
-            const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I;
+    for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*tile_A::I) {
+        const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
 #pragma unroll
-            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) {
-                mma_A K_A;
-                K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
-                KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]);
-            }
+        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
+            tile_A K_A;
+            load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
+            mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, ((tile_B *) Q_B)[k_KQ_0/tile_A::J]);
         }
+    }
 
-        __syncthreads();
+#ifndef CP_ASYNC_AVAILABLE
+    __syncthreads(); // Only needed if tile_K == tile_V.
+#endif // CP_ASYNC_AVAILABLE
 
-        if (use_logit_softcap) {
-            static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
+    if (use_logit_softcap) {
+        static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
 #pragma unroll
-            for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) {
+        for (int i = 0; i < KQ_stride/(np*tile_C_KQ::I); ++i) {
 #pragma unroll
-                for (int l = 0; l < mma_C_KQ::ne; ++l) {
-                    KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
-                }
+            for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
             }
         }
+    }
 
-        if (maskh) {
-            static_assert(KQ_stride % (np       *mma_C_KQ::I) == 0, "bad loop size");
-            static_assert(ncols     % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size");
+    if (maskh) {
+        static_assert(KQ_stride % (np       *tile_C_KQ::I) == 0, "bad loop size");
+        static_assert(ncols     % (nwarps/np*tile_C_KQ::J) == 0, "bad loop size");
 #pragma unroll
-            for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) {
-                const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I;
+        for (int i00 = 0; i00 < KQ_stride; i00 += np*tile_C_KQ::I) {
+            const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
 #pragma unroll
-                for (int l = 0; l < mma_C_KQ::ne; ++l) {
-                    const int i = i0 + mma_C_KQ::get_i(l);
-                    const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l);
+            for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                const int i = i0 + tile_C_KQ::get_i(l);
+                const int j = (threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l);
 
-                    KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
-                }
+                KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
             }
         }
+    }
 
-        // Calculate softmax for each KQ column using the current max. value.
-        // The divisor is stored in KQ_rowsum and will be applied at the end.
-        float2 KQ_max_new = KQ_max;
-        static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
+    // Calculate softmax for each KQ column using the current max. value.
+    // The divisor is stored in KQ_rowsum and will be applied at the end.
+    float2 KQ_max_new = KQ_max;
+    static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
 #pragma unroll
-        for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
+    for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
 #pragma unroll
-            for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) {
-                KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
-                KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
-            }
+        for (int l0 = 0; l0 < tile_C_KQ::ne; l0 += 2) {
+            KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
+            KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
         }
+    }
 
-        // Values per KQ column are spread across 8 threads, does not need full warp reduce:
+    // Values per KQ column are spread across 8 threads, does not need full warp reduce:
 #pragma unroll
-        for (int offset = 16; offset > 2; offset >>= 1) {
-            KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
-            KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
-        }
-
-        {
-            const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
-            KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
-            if (diff.x <= SOFTMAX_FTZ_THRESHOLD) {
-                KQ_max_scale.x = 0.0f;
-            }
-            if (diff.y <= SOFTMAX_FTZ_THRESHOLD) {
-                KQ_max_scale.y = 0.0f;
-            }
-            KQ_max = KQ_max_new;
-        }
+    for (int offset = 16; offset > 2; offset >>= 1) {
+        KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
+        KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
+    }
 
-        float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
-        static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
+    float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
+    static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
 #pragma unroll
-        for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
+    for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
 #pragma unroll
-            for (int l = 0; l < mma_C_KQ::ne; ++l) {
-                const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y;
-                const float diff = KQ_C[k].x[l] - KQ_max_l;
-                KQ_C[k].x[l] = expf(diff);
-                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
-                    KQ_C[k].x[l] = 0.0f;
-                }
+        for (int l = 0; l < tile_C_KQ::ne; ++l) {
+            const float KQ_max_l = l % 2 == 0 ? KQ_max_new.x : KQ_max_new.y;
+            const float diff = KQ_C[k].x[l] - KQ_max_l;
+            KQ_C[k].x[l] = expf(diff);
 
-                if (l % 2 == 0) {
-                    KQ_rowsum_add.x += KQ_C[k].x[l];
-                } else {
-                    KQ_rowsum_add.y += KQ_C[k].x[l];
-                }
+            if (l % 2 == 0) {
+                KQ_rowsum_add.x += KQ_C[k].x[l];
+            } else {
+                KQ_rowsum_add.y += KQ_C[k].x[l];
             }
         }
+    }
+
+    {
+        const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
+        const float2 KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
+        KQ_max = KQ_max_new;
 
         // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
         KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
@@ -244,60 +197,179 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
         const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
 #pragma unroll
-        for (int i = 0; i < D/mma_C_VKQ::I; ++i) {
+        for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
 #pragma unroll
-            for (int l = 0; l < mma_C_VKQ::ne; ++l) {
+            for (int l = 0; l < tile_C_VKQ::ne; ++l) {
                 VKQ_C[i].x[l] *= KQ_max_scale_h2;
             }
         }
+    }
+
+    // Convert KQ C tiles into B tiles for VKQ calculation:
+    tile_B B[KQ_stride/(np*2*tile_B::J)];
+    static_assert(KQ_stride % (np*2*tile_B::J) == 0, "bad loop size");
+#pragma unroll
+    for (int k = 0; k < KQ_stride/(np*2*tile_B::J); ++k) {
+        B[k] = get_transposed(get_half2(KQ_C[k]));
+    }
 
-        // Convert KQ C tiles into B tiles for VKQ calculation:
-        mma_B B[KQ_stride/(np*2*mma_B::K)];
-        static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size");
+#ifdef CP_ASYNC_AVAILABLE
+    cp_async_wait_all();
+    __syncthreads();
+    if (!last_iter) {
+        flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + (k_VKQ_0 + KQ_stride)*stride_KV, tile_K, stride_KV);
+    }
+#else
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
+    __syncthreads();
+#endif // CP_ASYNC_AVAILABLE
+
+    // Calculate VKQ tile:
+#pragma unroll
+    for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
+        static_assert((KQ_stride/2) % (np*tile_A::J) == 0, "bad loop size");
 #pragma unroll
-        for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) {
-            B[k] = KQ_C[k].to_mma_B();
+        for (int k00 = 0; k00 < KQ_stride/2; k00 += np*tile_A::J) {
+            const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
+
+            tile_A A;
+            load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
+            mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
+        }
+    }
+
+#ifndef CP_ASYNC_AVAILABLE
+    __syncthreads(); // Only needed if tile_K == tile_V.
+#endif // CP_ASYNC_AVAILABLE
+
+#else
+    NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+}
+
+template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
+static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
+        const float2 * const __restrict__ Q_f2,
+        const half2  * const __restrict__ K_h2,
+        const half2  * const __restrict__ V_h2,
+        const half   * const __restrict__ maskh,
+        float2       * const __restrict__ dstk,
+        float2       * const __restrict__ dstk_fixup,
+        const float scale,
+        const float slope,
+        const float logit_softcap,
+        const int ne01,
+        const int ne02,
+        const int stride_Q,
+        const int stride_KV,
+        const int stride_mask,
+        const int jt,
+        const int kb0_start,
+        const int kb0_stop) {
+#ifdef NEW_MMA_AVAILABLE
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps");
+    constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
+
+    static_assert(D         % nwarps == 0, "bad D");
+    static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
+
+    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+
+    // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements:
+    extern __shared__ half2 tile_K[];
+#ifdef CP_ASYNC_AVAILABLE
+    half2 * tile_V = tile_K + KQ_stride*D2_padded;
+#else
+    half2 * tile_V = tile_K;
+#endif // CP_ASYNC_AVAILABLE
+
+    tile_B Q_B[D/(2*tile_B::J)];
+    tile_C_VKQ VKQ_C[D/tile_C_VKQ::I];
+
+    float2 KQ_rowsum = {0.0f, 0.0f};
+    float2    KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
+
+    // Temporarily load Q data into tile_K, will be loaded into registers afterwards.
+    // The loading is done with decreasing granularity for D for better memory bandwidth.
+    const half2 scale_h2 = make_half2(scale, scale);
+#pragma unroll
+    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+        const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
+        const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
+        const int stride_j = WARP_SIZE / stride_k;
+
+        if (k0_start == k0_stop) {
+            continue;
+        }
+
+        if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
+            break;
         }
 
-        // Load V data into tile with decreasing granularity for D for better memory bandwidth:
-        static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
 #pragma unroll
-        for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-            const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i);
-            const int i0_stop  =                             D/2 - (D/2) % (1*stride_i);
-            const int stride_k = WARP_SIZE / stride_i;
+        for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
+            const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
 
+            if (jt*ncols + j < ne01) {
 #pragma unroll
-            for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) {
-                const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i);
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
 
+                    const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
+                    tile_K[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
+                }
+            } else {
 #pragma unroll
-                for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) {
-                    const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i);
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
 
-                    tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V];
+                    tile_K[j*D2_padded + k] = make_half2(0.0f, 0.0f);
                 }
             }
         }
+    }
 
-        __syncthreads();
+    __syncthreads();
 
-        // Calculate VKQ tile:
-#pragma unroll
-        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) {
-            static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size");
-#pragma unroll
-            for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) {
-                const int k0 = k00 + (threadIdx.y % np)*mma_A::K;
+    {
+        const int j0 = (threadIdx.y / np) * tile_B::I;
 
-                mma_A A;
-                A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
-                VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]);
-            }
+#pragma unroll
+        for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
+            load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
         }
+    }
+
+    __syncthreads();
 
+    // Preload K data for first iteration when using cp_async:
+#ifdef CP_ASYNC_AVAILABLE
+    flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + kb0_start*KQ_stride*stride_KV, tile_K, stride_KV);
+#endif // CP_ASYNC_AVAILABLE
+
+    // Iterate over ne11 == previous tokens:
+    for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
+        constexpr bool last_iter = false;
+        flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+            (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+    }
+    { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
+        constexpr bool last_iter = true;
+        flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+            (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+    }
+
+    // With cp_async there is no __syncthreads at the end of the iter,
+    //     there can be a race condition on shared memory access for combining/writing back results.
+#ifdef CP_ASYNC_AVAILABLE
+    if (nwarps*tile_B::I > KQ_stride) {
         __syncthreads();
     }
+#endif // CP_ASYNC_AVAILABLE
 
     // Finally, sum up partial KQ rowsums.
     // The partial sums are spread across 8 threads each, does not need full reduce.
@@ -310,26 +382,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     // Write VKQ accumulators to shared memory in column-major format.
     // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
     // Also for np > 1 the combination is done via these values in shared memory.
-    const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data
+    const int j_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // j combine write data
 #pragma unroll
-    for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
-        const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format.
+    for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
+        const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
 
 #pragma unroll
-        for (int l = 0; l < mma_B::ne; ++l) {
-            const int k = k0 + mma_B::get_k(l);
+        for (int l = 0; l < tile_B::ne; ++l) {
+            const int k = k0 + tile_B::get_j(l);
 
-            tile_KV[j_cwd*D2_padded + k] = B.x[l];
+            tile_K[j_cwd*D2_padded + k] = B.x[l];
         }
     }
 
-    const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset
-    const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
+    const int j_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // j combine write meta offset
+    const int j_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
     const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum
 
-    if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) {
+    if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
         // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
-        ((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
+        ((float2 *) tile_K)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
     }
 
     __syncthreads();
@@ -337,11 +409,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     static_assert(np == 1 || np == 2 || np == 4, "bad np");
     if (np == 1) {
         // No combination is needed, the meta data can be directly written from registers to VRAM.
-        if (needs_fixup && threadIdx.x < mma_B::J) {
+        if (needs_fixup && threadIdx.x < tile_B::I) {
             float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
             dstk_fixup_meta[j_cwm] = KQ_cmr;
         }
-        if (is_fixup && threadIdx.x < mma_B::J) {
+        if (is_fixup && threadIdx.x < tile_B::I) {
             float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
             dstk_fixup_meta[j_cwm] = KQ_cmr;
         }
@@ -350,42 +422,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         // Warps with threadIdx.y % np != 0 must NOT return early.
         // All threads must return simultaneously to avoid race conditions with work on the next tile.
 
-        float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2;
+        float * meta_j = (float *) tile_K + (threadIdx.y*tile_B::I + threadIdx.x)*D2_padded + D/2;
 
         float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
-        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
+        if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
             KQ_cm = meta_j[0];
         }
 
         float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
 #pragma unroll
-        for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
+        for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
             KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
         }
 
         const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
         float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
-        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
+        if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
             KQ_crs = KQ_cms*meta_j[1];
         }
 #pragma unroll
-        for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
+        for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
             KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
         }
 
         // Write back combined meta data:
-        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
-            meta_j[0] = KQ_cmn; // Combined max. KQ values.
-            meta_j[1] = KQ_crs; // Combined KQ rowsums.
-            meta_j[2] = KQ_cms; // KQ max scales per parallel warp.
+        if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
+            *((float2 *) meta_j) = make_float2(KQ_cms, KQ_crs); // Combined KQ max scale + rowsum.
         }
-        if (needs_fixup && threadIdx.x < mma_B::J) {
+        if (needs_fixup && threadIdx.x < tile_B::I) {
             float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
-            dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+            dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
         }
-        if (is_fixup && threadIdx.x < mma_B::J) {
+        if (is_fixup && threadIdx.x < tile_B::I) {
             float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
-            dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+            dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
         }
     }
 
@@ -404,6 +474,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
             const int stride_j = WARP_SIZE / stride_k;
 
+            if (k0_start == k0_stop) {
+                continue;
+            }
+
             if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
                 break;
             }
@@ -411,12 +485,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 #pragma unroll
             for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
                 const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
-                const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J;
+                const int j_tile_K = (j_dst/tile_B::I)*(np*tile_B::I) + j_dst % tile_B::I;
 
                 if (!is_fixup && jt*ncols + j_dst >= ne01) {
                     continue;
                 }
-                const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2;
+                const float * meta_j = (const float *) tile_K + j_tile_K*D2_padded + D/2;
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                     const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -424,8 +498,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                     float2 dstk_val = make_float2(0.0f, 0.0f);
 #pragma unroll
                     for (int ip = 0; ip < np; ++ip) {
-                        const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2];
-                        const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]);
+                        const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*tile_B::I*D2_padded + 0];
+                        const float2 dstk_val_add = __half22float2(tile_K[(j_tile_K + ip*tile_B::I)*D2_padded + k]);
                         dstk_val.x += dstk_val_add.x*KQ_crs;
                         dstk_val.y += dstk_val_add.y*KQ_crs;
                     }
@@ -450,7 +524,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         __syncthreads();
     }
 #else
-   NO_DEVICE_CODE;
+    NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
 
@@ -494,6 +568,11 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
+#ifndef NEW_MMA_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // NEW_MMA_AVAILABLE
+
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -504,6 +583,10 @@ static __global__ void flash_attn_ext_f16(
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
+    const int stride_Q    = nb01 / sizeof(float2);
+    const int stride_KV   = nb11 / sizeof(half2);
+    const int stride_mask = nb31 / sizeof(half);
+
     const int iter_k = ne11 / KQ_stride;
     const int iter_j = (ne01 + (ncols - 1)) / ncols;
 
@@ -535,14 +618,12 @@ static __global__ void flash_attn_ext_f16(
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
             flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
-                ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
-                jt, kb0_start, kb0_stop);
+                 ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
             flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
-                ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
-                jt, kb0_start, kb0_stop);
+                 ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
         }
 
         kbc += iter_k;
@@ -571,24 +652,27 @@ static __global__ void flash_attn_ext_f16(
     constexpr bool needs_fixup = false;
     flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
         (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
-        ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
-        jt, kb0_start, kb0_stop);
+         ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
 }
 
 template <int D, int cols_per_block>
 void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    typedef mma_A_I16K8<half2> mma_A;
-    typedef mma_B_J8K8<half2>  mma_B;
+    typedef tile<16, 8, half2> tile_A;
+    typedef tile< 8, 8, half2> tile_B;
 
-    static_assert(D              % mma_B::K == 0, "bad D");
-    static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block");
+    static_assert(D              % tile_B::J == 0, "bad D");
+    static_assert(cols_per_block % tile_B::I == 0, "bad cols_per_block");
 
     const ggml_tensor * KQV = dst;
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    constexpr int KQ_stride = D <= 128 ? 64 : 32;
+    constexpr int nwarps    = (KQ_stride == 32 && cols_per_block <= 16) ?
+                              cols_per_block/tile_B::J * KQ_stride/tile_A::I : (cols_per_block <= 8 ? 4 : 8);
 
-    constexpr int    KQ_stride     = D <= 128 ? 64 : 32;
-    constexpr int    nwarps        = (KQ_stride == 32 && cols_per_block <= 16) ?
-                                     cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8);
-    constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half);
+    const int    nrows_KQ      = cp_async_available(cc) ? 2*KQ_stride : KQ_stride;
+    const int    nrows_combine = nwarps*tile_B::J;
+    const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * sizeof(half);
 
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
index bbc0a35ae561669231f845aa8a5584cbc3922802..0a5656e4cb3101630835e9c5716b84e46fefcd20 100644 (file)
@@ -4,11 +4,12 @@
 //   https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
 //
 // Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
-// A is a row-major matrix with shape I x K.
-// B is a column-major matrix with shape K x J.
-// C is a column-major matrix with shape I x J.
-// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements.
-// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile.
+// A is a row-major matrix with shape M x K.
+// B is a column-major matrix with shape K x N.
+// C is a column-major matrix with shape M x N.
+// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
+// Note that J is measured in physical 32 bit elements instead of logical elements.
+// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
 // All matrix tiles have ne physical 32 bit elements per warp.
 //
 // As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
@@ -23,7 +24,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
 
 #ifdef NEW_MMA_AVAILABLE
     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-        : "+r"(ret) : "r"(x));
+        : "=r"(ret) : "r"(x));
 #else
     NO_DEVICE_CODE;
 #endif // defined(NEW_MMA_AVAILABLE)
@@ -52,407 +53,267 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
 
 #endif // CUDART_VERSION >= 11080
 
+static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
+    half2 ret;
+    *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
+    return ret;
+}
 
-template <typename T>
-struct mma_A_I16K4 {
-    static_assert(sizeof(T) == 4, "bad type size");
-
-    static constexpr int I  = 16;
-    static constexpr int K  = 4;
-    static constexpr int ne = 2;
-
-    T x[ne];
+namespace ggml_cuda_mma {
+
+    template <int I_, int J_, typename T>
+    struct tile {
+        static constexpr int I  = I_;
+        static constexpr int J  = J_;
+        static constexpr int ne = I * J / WARP_SIZE;
+        T x[ne] = {0};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && (J == 4 || J == 8)) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l / 2) * 8 + threadIdx.x / 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
 
-    static __device__ __forceinline__ int get_i(const int l) {
-        const int ret = (l%2) * (I/2) + threadIdx.x / K;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  I);
-        return ret;
-    }
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return threadIdx.x % 4;
+            } else if constexpr (I == 8 && J == 8) {
+                return 4 * l + threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return 2 * (threadIdx.x % 4) + l % 2;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+    };
+
+    template <int I_, int J_>
+    struct tile<I_, J_, half2> {
+        static constexpr int I  = I_;
+        static constexpr int J  = J_;
+        static constexpr int ne = I * J / WARP_SIZE;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return l * 8 + threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l % 2) * 8 + threadIdx.x / 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
 
-    static __device__ __forceinline__ int get_k(const int /* l */) {
-        const int ret = threadIdx.x % K;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  K);
-        return ret;
-    }
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return l * 4 + threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l / 2) * 4 + threadIdx.x % 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+    };
 
-    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
+    template <int I, int J>
+    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
+        tile<I, J/2, half2> ret;
 #pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
+            ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
         }
-    }
-
-    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
-#ifdef NEW_MMA_AVAILABLE
-        int * xi = (int *) x;
-        const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride;
-        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
-            : "+r"(xi[0]), "+r"(xi[1])
-            : "l"(xs));
-#else
-        load_generic(xs0, stride);
-#endif // NEW_MMA_AVAILABLE
-    }
-};
-
-template <typename T>
-struct mma_A_I16K8 {
-    static_assert(sizeof(T) == 4, "bad type size");
-
-    static constexpr int I  = 16;
-    static constexpr int K  = 8;
-    static constexpr int ne = 4;
-
-    T x[ne];
-
-    static __device__ __forceinline__ int get_i(const int l) {
-        const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  I);
         return ret;
     }
 
-    static __device__ __forceinline__ int get_k(const int l) {
-        const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  K);
+    static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
+        tile<8, 8, half2> ret;
+        ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
+        ret.x[1] = ggml_cuda_movmatrix(t.x[1]);
+
         return ret;
     }
 
-    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
+    template <int I, int J, typename T>
+    static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
 #pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        for (int l = 0; l < t.ne; ++l) {
+            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
         }
     }
 
-    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
 #ifdef NEW_MMA_AVAILABLE
-        int * xi = (int * ) x;
-        const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
-        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
-            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+        int * xi = (int *) t.x;
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
+        asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
-        GGML_UNUSED(xs0);
-        GGML_UNUSED(stride);
-        NO_DEVICE_CODE;
+        load_generic(t, xs0, stride);
 #endif // NEW_MMA_AVAILABLE
     }
 
-    __device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) {
+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
 #ifdef NEW_MMA_AVAILABLE
-        int * xi = (int * ) x;
-        const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
-        asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
-            : "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3])
+        int * xi = (int *) t.x;
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
+        asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
-        GGML_UNUSED(xs0);
-        GGML_UNUSED(stride);
-        NO_DEVICE_CODE;
+        load_generic(xs0, stride);
 #endif // NEW_MMA_AVAILABLE
     }
 
-    __device__ __forceinline__ void transpose() {
-        int * xi  = (int *) x;
-        xi[0] = ggml_cuda_movmatrix(xi[0]);
-
-        const int tmp = ggml_cuda_movmatrix(xi[1]);
-        xi[1] = ggml_cuda_movmatrix(xi[2]);
-        xi[2] = tmp;
-
-        xi[3] = ggml_cuda_movmatrix(xi[3]);
-    }
-};
-
-template <typename T>
-struct mma_B_J8K4 {
-    static_assert(sizeof(T) == 4, "bad type size");
-
-    static constexpr int J  = 8;
-    static constexpr int K  = 4;
-    static constexpr int ne = 1;
-
-    T x[ne];
-
-    static __device__ __forceinline__ int get_j(const int /* l */) {
-        const int ret = threadIdx.x / K;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
-    }
-
-    static __device__ __forceinline__ int get_k(const int /* l */) {
-        const int ret = threadIdx.x % K;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  K);
-        return ret;
-    }
-
-    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
-#pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_j(l)*stride + get_k(l)];
-        }
-    }
-
-    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
 #ifdef NEW_MMA_AVAILABLE
-        int * xi = (int *) x;
-        const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride;
-        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
-            : "+r"(xi[0]) : "l"(xs));
+        int * xi = (int * ) t.x;
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
+        asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
+            : "l"(xs));
 #else
-        load_generic(xs0, stride);
+        load_generic(t, xs0, stride);
 #endif // NEW_MMA_AVAILABLE
     }
-};
-
-template <typename T>
-struct mma_B_J8K8 {
-    static_assert(sizeof(T) == 4, "bad type size");
-
-    static constexpr int J  = 8;
-    static constexpr int K  = 8;
-    static constexpr int ne = 2;
 
-    T x[ne];
-
-    static __device__ __forceinline__ int get_j(const int /* l */) {
-        const int ret = threadIdx.x / (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
-    }
-
-    static __device__ __forceinline__ int get_k(const int l) {
-        const int ret = l * (K/2) + threadIdx.x % (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  K);
-        return ret;
-    }
-
-    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
-#pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_j(l)*stride + get_k(l)];
-        }
-    }
-
-    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
+    template <typename T>
+    static __device__ __forceinline__ void load_ldmatrix_trans(
+            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
 #ifdef NEW_MMA_AVAILABLE
-        int * xi = (int *x;
-        const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
-        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
-            : "+r"(xi[0]), "+r"(xi[1])
+        int * xi = (int * ) t.x;
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
+        asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
+            : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
             : "l"(xs));
 #else
-        load_generic(xs0, stride);
+        GGML_UNUSED(t);
+        GGML_UNUSED(xs0);
+        GGML_UNUSED(stride);
+        NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
     }
-};
-
-template <typename T>
-struct mma_C_I16J8 {};
-
-template <>
-struct mma_C_I16J8<int> {
-    static constexpr int I  = 16;
-    static constexpr int J  = 8;
-    static constexpr int ne = 4;
 
-    int x[ne] = {0};
-
-    static __device__ __forceinline__ int get_i(const int l) {
-        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  I);
-        return ret;
-    }
-
-    static __device__ __forceinline__ int get_j(const int l) {
-        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
-    }
-
-    __device__ __forceinline__ void mma(const mma_A_I16K4<int> & mma_A, const mma_B_J8K4<int> & mma_B) {
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
 #ifdef NEW_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
-            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
-            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+            : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
+            : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
 #else
         // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[0]), "+r"(x[1])
-            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+            : "+r"(D.x[0]), "+r"(D.x[1])
+            : "r"(A.x[0]), "r"(B.x[0]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[2]), "+r"(x[3])
-            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+            : "+r"(D.x[2]), "+r"(D.x[3])
+            : "r"(A.x[1]), "r"(B.x[0]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(mma_A);
-        GGML_UNUSED(mma_B);
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
         NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
     }
 
-    __device__ __forceinline__ void mma(const mma_A_I16K8<int> & mma_A, const mma_B_J8K8<int> & mma_B) {
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
 #ifdef NEW_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
-            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
-            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
+            : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
+            : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
 #else
         // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[0]), "+r"(x[1])
-            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+            : "+r"(D.x[0]), "+r"(D.x[1])
+            : "r"(A.x[0]), "r"(B.x[0]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[2]), "+r"(x[3])
-            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+            : "+r"(D.x[2]), "+r"(D.x[3])
+            : "r"(A.x[1]), "r"(B.x[0]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[0]), "+r"(x[1])
-            : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
+            : "+r"(D.x[0]), "+r"(D.x[1])
+            : "r"(A.x[2]), "r"(B.x[1]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[2]), "+r"(x[3])
-            : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
+            : "+r"(D.x[2]), "+r"(D.x[3])
+            : "r"(A.x[3]), "r"(B.x[1]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(mma_A);
-        GGML_UNUSED(mma_B);
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
         NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
     }
-};
-
-template <>
-struct mma_C_I16J8<half2> {
-    static constexpr int I  = 16;
-    static constexpr int J  = 4;
-    static constexpr int ne = 2;
-
-    half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}};
-
-    static __device__ __forceinline__ int get_i(const int l) {
-        const int ret = l * (I/2) + threadIdx.x / J;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  I);
-        return ret;
-    }
 
-    static __device__ __forceinline__ int get_j(const int /* l */) {
-        const int ret = threadIdx.x % J;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
-    }
-
-    __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
+    static __device__ __forceinline__ void mma(
+            tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
 #ifdef NEW_MMA_AVAILABLE
-        int * Axi = (int *) mma_A.x;
-        int * Bxi = (int *) mma_B.x;
-        int * xi  = (int *) x;
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
-            : "+r"(xi[0]), "+r"(xi[1])
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
             : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
 #else
         // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
         asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
-            : "+r"(xi[0]), "+r"(xi[1])
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
             : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
         asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
-            : "+r"(xi[0]), "+r"(xi[1])
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(mma_A);
-        GGML_UNUSED(mma_B);
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
         NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
     }
 
-    __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
-        mma_B_J8K8<half2> mma_B;
-
-        int * xi   = (int *) x;
-        int * Bxi  = (int *) mma_B.x;
-        Bxi[0] = ggml_cuda_movmatrix(xi[0]);
-        Bxi[1] = ggml_cuda_movmatrix(xi[1]);
-
-        return mma_B;
-    }
-};
-
-template <>
-struct mma_C_I16J8<float> {
-    static constexpr int I  = 16;
-    static constexpr int J  = 8;
-    static constexpr int ne = 4;
-
-    float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f};
-
-    static __device__ __forceinline__ int get_i(const int l) {
-        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  I);
-        return ret;
-    }
-
-    static __device__ __forceinline__ int get_j(const int l) {
-        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
-    }
-
-    __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
 #ifdef NEW_MMA_AVAILABLE
-        int * Axi = (int *) mma_A.x;
-        int * Bxi = (int *) mma_B.x;
-        int * xi  = (int *) x;
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
-            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
             : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
 #else
         // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
         asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
-            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
             : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
         asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
-            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(mma_A);
-        GGML_UNUSED(mma_B);
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
         NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
     }
 
-    __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
-        mma_B_J8K8<half2> mma_B;
-        mma_B.x[0] = make_half2(x[0], x[1]);
-        mma_B.x[1] = make_half2(x[2], x[3]);
-
-        int * Bxi  = (int *) mma_B.x;
-        Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
-        Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
-
-        return mma_B;
-    }
-
-    __device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) {
-#pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_j(l)*stride + get_i(l)];
-        }
-    }
-};
+}
index 5391542086cfd3080386360d0d20a41f4730e6d2..0451c65f302777d9514527dbac5eb9fa6668e03f 100644 (file)
@@ -7,6 +7,8 @@
 #include <climits>
 #include <cstdint>
 
+using namespace ggml_cuda_mma;
+
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 #define MMQ_ITER_K 256
 #define MMQ_NWARPS 8
@@ -647,15 +649,15 @@ template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
-    typedef mma_A_I16K8<int> mma_A;
-    typedef mma_B_J8K8<int>  mma_B;
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 8, int> tile_A;
+    typedef tile< 8, 8, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
     const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
@@ -663,8 +665,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const float * y_df = (const float *) y;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A[ntx][WARP_SIZE/QI8_0];
-    float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+    tile_A A[ntx][WARP_SIZE/QI8_0];
+    float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
@@ -674,12 +676,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
             const int k0 = k00 + k01;
 
-            A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+            load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
@@ -691,17 +693,17 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            mma_B  B;
-            float dB[mma_C::ne/2];
+            tile_B B;
+            float dB[tile_C::ne/2];
 
-            B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
                     dB[l] =             y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
@@ -712,12 +714,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C;
-                C.mma(A[n][k01/QI8_0], B);
+                tile_C C;
+                mma(C, A[n][k01/QI8_0], B);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
                 }
             }
         }
@@ -758,23 +760,23 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
-    typedef mma_A_I16K8<int> mma_A;
-    typedef mma_B_J8K8<int>  mma_B;
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 8, int> tile_A;
+    typedef tile< 8, 8, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
     const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_dm = (const half2 *) y;
 
-    mma_A    A[ntx][WARP_SIZE/QI8_1];
-    float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
+    tile_A   A[ntx][WARP_SIZE/QI8_1];
+    float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
@@ -784,12 +786,12 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
             const int k0 = k00 + k01;
 
-            A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+            load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
@@ -801,30 +803,30 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
-            mma_B    B;
-            float2 dsB[mma_C::ne/2];
+            tile_B   B;
+            float2 dsB[tile_C::ne/2];
 
-            B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C;
-                C.mma(A[n][k01/QI8_1], B);
+                tile_C C;
+                mma(C, A[n][k01/QI8_1], B);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
                 }
             }
         }
@@ -868,26 +870,26 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef NEW_MMA_AVAILABLE
 
-    typedef mma_A_I16K4<int> mma_A;
-    typedef mma_A_I16K8<int> mma_A_K8;
-    typedef mma_B_J8K4<int>  mma_B;
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 4, int> tile_A;
+    typedef tile<16, 8, int> tile_A_8;
+    typedef tile< 8, 4, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
     const float * x_df = (const float *) x_qs + WARP_SIZE*2;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
 
-    mma_A   A[ntx][8];
-    float  dA[ntx][mma_C::ne/2][8];
+    tile_A  A[ntx][8];
+    float  dA[ntx][tile_C::ne/2][8];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -895,12 +897,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             const int k0 = k00 + k01;
 
-            ((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+            load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
@@ -912,32 +914,32 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
-            mma_B B[2];
-            float dB[mma_C::ne/2];
+            tile_B B[2];
+            float dB[tile_C::ne/2];
 
             // Here load_generic is faster than load_ldmatrix.
-            B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
-            B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),         MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C[2];
-                C[0].mma(A[n][k01/4 + 0], B[0]);
-                C[1].mma(A[n][k01/4 + 1], B[1]);
+                tile_C C[2];
+                mma(C[0], A[n][k01/4 + 0], B[0]);
+                mma(C[1], A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
                 }
             }
         }
@@ -1056,27 +1058,27 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef NEW_MMA_AVAILABLE
 
-    typedef mma_A_I16K4<int> mma_A;
-    typedef mma_A_I16K8<int> mma_A_K8;
-    typedef mma_B_J8K4<int>  mma_B;
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 4, int> tile_A;
+    typedef tile<16, 8, int> tile_A_8;
+    typedef tile< 8, 4, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
     const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
 
-    mma_A   A[ntx][8];
-    float  dA[ntx][mma_C::ne/2][8];
-    float  mA[ntx][mma_C::ne/2][8];
+    tile_A  A[ntx][8];
+    float  dA[ntx][tile_C::ne/2][8];
+    float  mA[ntx][tile_C::ne/2][8];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -1084,15 +1086,15 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
             const int k0 = k00 + k01;
 
-            ((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+            load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
         }
     }
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
@@ -1107,58 +1109,58 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        float2 dB[mma_C::ne/2];
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+        float2 dB[tile_C::ne/2];
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int j = j0 + mma_C::get_j(l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int j = j0 + tile_C::get_j(l);
 
             dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
         }
 
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
-            mma_B B[2];
+            tile_B B[2];
 
             // Here load_generic is faster than load_ldmatrix.
-            B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
-            B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),         MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
 
-            mma_C Cm[2];
+            tile_C Cm[2];
             if (k01 >= WARP_SIZE * 3/4) {
-                mma_A A1;
+                tile_A A1;
                 A1.x[0] = 0x01010101;
                 A1.x[1] = 0x01010101;
-                Cm[0].mma(A1, B[0]);
-                Cm[1].mma(A1, B[1]);
+                mma(Cm[0], A1, B[0]);
+                mma(Cm[1], A1, B[1]);
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C Cd[2];
+                tile_C Cd[2];
 
-                Cd[0].mma(A[n][k01/4 + 0], B[0]);
-                Cd[1].mma(A[n][k01/4 + 1], B[1]);
+                mma(Cd[0], A[n][k01/4 + 0], B[0]);
+                mma(Cd[1], A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
+                for (int l = 0; l < tile_C::ne; ++l) {
                     float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
                     if (k01 >= WARP_SIZE * 3/4) {
                         tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
                     }
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
                 }
             }
         }
 
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
-            float2 sB[mma_C::ne/2];
+            float2 sB[tile_C::ne/2];
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
             }
@@ -1166,9 +1168,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
                 }
             }
         }
@@ -1708,15 +1710,15 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef NEW_MMA_AVAILABLE
 
-    typedef mma_A_I16K4<int> mma_A;
-    typedef mma_B_J8K4<int>  mma_B;
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 4, int> tile_A;
+    typedef tile< 8, 4, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
     const float * x_df = (const float *) x_qs + WARP_SIZE*2;
@@ -1724,11 +1726,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
 
-    mma_A   A[ntx][8];
-    int   scA[ntx][mma_C::ne/2][8];
-    float  dA[ntx][mma_C::ne/2];
+    tile_A   A[ntx][8];
+    int    scA[ntx][tile_C::ne/2][8];
+    float   dA[ntx][tile_C::ne/2];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -1736,8 +1738,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             const int k0 = k00 + k01;
 
-            A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0),        MMQ_MMA_TILE_X_K_Q6_K);
-            A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+            load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0),         MMQ_MMA_TILE_X_K_Q6_K);
+            load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
         }
 
 #pragma unroll
@@ -1745,8 +1747,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
             const int k0 = k00 + k01;
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
                 const int      sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
                 const int8_t * sc        = (const int8_t *) &sc_packed;
@@ -1759,41 +1761,41 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
             dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
         }
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        float tmp[ntx][mma_C::ne] = {{0.0f}};
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+        float tmp[ntx][tile_C::ne] = {{0.0f}};
 
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
-            mma_B B[2];
-            float dB[mma_C::ne/2];
+            tile_B B[2];
+            float dB[tile_C::ne/2];
 
             // Here load_generic is faster than load_ldmatrix.
-            B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0        + k01, MMQ_TILE_Y_K);
-            B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0         + k01, MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C[2];
-                C[0].mma(A[n][k01/4 + 0], B[0]);
-                C[1].mma(A[n][k01/4 + 1], B[1]);
+                tile_C C[2];
+                mma(C[0], A[n][k01/4 + 0], B[0]);
+                mma(C[1], A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
+                for (int l = 0; l < tile_C::ne; ++l) {
                     tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
                 }
             }
@@ -1802,8 +1804,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
+            for (int l = 0; l < tile_C::ne; ++l) {
+                sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
             }
         }
     }
@@ -2312,36 +2314,36 @@ template<int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_mma(
     const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
 
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
 #ifdef NEW_MMA_AVAILABLE
-    static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+    static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
 #endif // NEW_MMA_AVAILABLE
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne; ++l) {
+                const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);
 
                 if (j > j_max) {
                     continue;
                 }
 
-                const int i = i0 + n*mma_C::I + mma_C::get_i(l);
+                const int i = i0 + n*tile_C::I + tile_C::get_i(l);
 
                 if (need_check && i > i_max) {
                     continue;
                 }
 
-                dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
+                dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
             }
         }
     }