]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: use mma PTX instructions for FlashAttention (#11583)
authorJohannes Gäßler <redacted>
Sun, 2 Feb 2025 18:31:09 +0000 (19:31 +0100)
committerGitHub <redacted>
Sun, 2 Feb 2025 18:31:09 +0000 (19:31 +0100)
* CUDA: use mma PTX instructions for FlashAttention

* __shfl_sync workaround for movmatrix

* add __shfl_sync to HIP

Co-authored-by: Diego Devesa <redacted>
29 files changed:
Makefile
ggml/include/ggml.h
ggml/src/ggml-cuda/CMakeLists.txt
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/fattn-common.cuh
ggml/src/ggml-cuda/fattn-mma-f16.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/fattn-tile-f16.cu
ggml/src/ggml-cuda/fattn-tile-f32.cu
ggml/src/ggml-cuda/fattn-vec-f16.cuh
ggml/src/ggml-cuda/fattn-vec-f32.cuh
ggml/src/ggml-cuda/fattn-wmma-f16.cu [new file with mode: 0644]
ggml/src/ggml-cuda/fattn-wmma-f16.cuh
ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-cuda/mma.cuh
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmq.cuh
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu [deleted file]
ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu [deleted file]
ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu [deleted file]
ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu [deleted file]
ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu [deleted file]
ggml/src/ggml-cuda/template-instances/generate_cu_files.py
ggml/src/ggml-cuda/vendors/hip.h
ggml/src/ggml-hip/CMakeLists.txt
ggml/src/ggml-musa/CMakeLists.txt

index ef152d2467ed564f759e9cf4a6ea4fcf634fcb3c..dc3de3cb14e44d01d0ffdaaded39e071e3c1ad37 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -596,7 +596,7 @@ ifdef GGML_RPC
        OBJ_GGML_EXT += ggml/src/ggml-rpc.o
 endif # GGML_RPC
 
-OBJ_CUDA_TMPL      = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-wmma*.cu))
+OBJ_CUDA_TMPL      = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-mma*.cu))
 OBJ_CUDA_TMPL     += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/mmq*.cu))
 
 ifdef GGML_CUDA_FA_ALL_QUANTS
index 1198dc1fd93785d1f7eebb04f615e278448524e2..5bd8d9c8b50232fb273a4a7808d32504876bb848 100644 (file)
@@ -1775,7 +1775,7 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   k);
 
-#define GGML_KQ_MASK_PAD 32
+#define GGML_KQ_MASK_PAD 64
 
     // q:    [n_embd, n_batch,     n_head,    1]
     // k:    [n_embd, n_kv,        n_head_kv, 1]
index 14761650f8a8e7efe5beef5a40b172c69113285f..119fd39b8e437fe78af2c28ab8740ea7b1467a64 100644 (file)
@@ -28,7 +28,7 @@ if (CUDAToolkit_FOUND)
     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
 
     file(GLOB   GGML_SOURCES_CUDA "*.cu")
-    file(GLOB   SRCS "template-instances/fattn-wmma*.cu")
+    file(GLOB   SRCS "template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB   SRCS "template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
index 8d8d3932e0e5864a9343b467079a53f3f7107bb0..88be8fc8a6ec6d19972516deb240cda14cead409 100644 (file)
@@ -148,7 +148,7 @@ typedef float2 dfloat2;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-#define INT8_MMA_AVAILABLE
+#define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
 #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
@@ -159,11 +159,13 @@ static constexpr bool fast_fp16_available(const int cc) {
     return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
 }
 
+// Any FP16 tensor cores are available.
 static constexpr bool fp16_mma_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
 }
 
-static constexpr bool int8_mma_available(const int cc) {
+// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
+static constexpr bool new_mma_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
 }
 
index ee9752da6a845d08caf1440921262107921f4f0f..cfd7c0f4475dc2e1829697161de11f9d4ed254a9 100644 (file)
@@ -516,6 +516,104 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+template<int D, int ncols, int KQ_stride> // D == head size
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_stream_k_fixup(
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / KQ_stride;
+    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+
+    const int bidx0 = blockIdx.x;
+
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+
+    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
+    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
+    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+        return;
+    }
+
+    const int channel = kbc0 / (iter_k*iter_j);
+    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+
+    dst += jt*ncols*ne02*D + channel*D;
+
+    // Load the partial result that needs a fixup:
+    float dst_val[ncols] = {0.0f};
+    float max_val[ncols] = {0.0f};
+    float rowsum[ncols]  = {0.0f};
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            break;
+        }
+        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+
+        const float2 tmp = dst_fixup[bidx0*ncols + j];
+        max_val[j] = tmp.x;
+        rowsum[j]  = tmp.y;
+    }
+
+    // Iterate over previous blocks and compute the combined results.
+    // All CUDA blocks that get here must have a previous block that needs a fixup.
+    int bidx = bidx0 - 1;
+    int kbc_stop = kbc0;
+    while(true) {
+        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        if (kbc == kbc_stop) { // Did not have any data.
+            bidx--;
+            kbc_stop = kbc;
+            continue;
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (jt*ncols + j >= ne01) {
+                break;
+            }
+            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
+
+            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
+
+            // Scale the current and new value accumulators depending on the max. values.
+            const float max_val_new = fmaxf(max_val[j], tmp.x);
+
+            const float diff_val = max_val[j] - max_val_new;
+            const float diff_add = tmp.x      - max_val_new;
+
+            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+
+            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
+            rowsum[j]  = scale_val*rowsum[j]  + scale_add*tmp.y;
+
+            max_val[j] = max_val_new;
+        }
+
+        // If this block started in a previous tile we are done and don't need to combine additional partial results.
+        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+            break;
+        }
+        bidx--;
+        kbc_stop = kbc;
+    }
+
+    // Write back final result:
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            return;
+        }
+        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
+    }
+}
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -581,10 +679,11 @@ static void on_no_fattn_vec_case(const int D) {
     }
 }
 
-template <int D, int parallel_blocks>
+// parallel_blocks == 0 is stream-k decomposition
+template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-    const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
@@ -603,20 +702,23 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
+    GGML_ASSERT(Q->ne[3] == 1);
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
-    char * K_data = (char *) K->data;
+    const char * K_data = (const char *) K->data;
     size_t nb11 = K->nb[1];
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    char * V_data = (char *) V->data;
+    const char * V_data = (const char *) V->data;
     size_t nb21 = V->nb[1];
     size_t nb22 = V->nb[2];
     size_t nb23 = V->nb[3];
@@ -649,39 +751,60 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
+    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
+    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
+    dim3 blocks_num;
+    if (parallel_blocks == 0) {
+        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+        const int tiles_nwaves  = (ntiles_total - nsm - 1) / nsm;
+        const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
+        const bool short_context = K->ne[1] < 4096;
+
+        const int nblocks_stream_k = 2*nsm;
+
+        blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
+        blocks_num.y = 1;
+        blocks_num.z = 1;
+
+        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
+    } else {
+        blocks_num.x = parallel_blocks*ntiles_x;
+        blocks_num.y = Q->ne[2];
+        blocks_num.z = Q->ne[3];
+
+        if (parallel_blocks > 1) {
+            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+        }
+    }
+
 
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;
 
-    memcpy(&scale,         (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias,      (float *) KQV->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (logit_softcap != 0.0f) {
         scale /= logit_softcap;
     }
 
     const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
 
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+    fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -693,16 +816,22 @@ void launch_fattn(
     );
     CUDA_CHECK(cudaGetLastError());
 
-    if ((parallel_blocks) == 1) {
-        return;
-    }
+    if constexpr (parallel_blocks == 0) {
+        if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 blocks_num_combine = blocks_num;
 
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
+            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
+                <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+        }
+    } else if constexpr (parallel_blocks > 1) {
+        const dim3 block_dim_combine(D, 1, 1);
+        const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
 
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+        flash_attn_combine_results<D, parallel_blocks>
+            <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    }
     CUDA_CHECK(cudaGetLastError());
 }
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
new file mode 100644 (file)
index 0000000..05bc91a
--- /dev/null
@@ -0,0 +1,637 @@
+#include "common.cuh"
+#include "mma.cuh"
+#include "fattn-common.cuh"
+
+template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
+static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
+        const float2 * const __restrict__ Q_f2,
+        const half2  * const __restrict__ K_h2,
+        const half2  * const __restrict__ V_h2,
+        const half   * const __restrict__ maskh,
+        float2       * const __restrict__ dstk,
+        float2       * const __restrict__ dstk_fixup,
+        const float scale,
+        const float slope,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3,
+        const int jt,
+        const int kb0_start,
+        const int kb0_stop) {
+#ifdef NEW_MMA_AVAILABLE
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    typedef mma_A_I16K8<half2> mma_A;
+    typedef mma_B_J8K8<half2>  mma_B;
+    typedef mma_C_I16J8<float> mma_C_KQ;
+    typedef mma_C_I16J8<half2> mma_C_VKQ;
+
+    static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps");
+    constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column.
+
+    static_assert(D         % nwarps == 0, "bad D");
+    static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
+
+    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+    extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements.
+
+    const int stride_Q    = nb01 / sizeof(float2);
+    const int stride_KV   = nb11 / sizeof(half2);
+    const int stride_mask = nb31 / sizeof(half);
+
+    mma_B Q_B[D/(2*mma_B::K)];
+    mma_C_VKQ VKQ_C[D/mma_C_VKQ::I];
+
+    float2    KQ_rowsum = {0.0f, 0.0f};
+    float2       KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
+    float2 KQ_max_scale = {0.0f, 0.0f};
+
+    // Temporarily load Q data into tile_KV, will be loaded into registers afterwards.
+    // The loading is done with decreasing granularity for D for better memory bandwidth.
+    const half2 scale_h2 = make_half2(scale, scale);
+#pragma unroll
+    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+        const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
+        const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
+        const int stride_j = WARP_SIZE / stride_k;
+
+        if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
+            break;
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
+            const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+            if (jt*ncols + j < ne01) {
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
+                    tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
+                }
+            } else {
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f);
+                }
+            }
+        }
+    }
+
+    __syncthreads();
+
+    {
+        const int j0 = (threadIdx.y / np) * mma_B::J;
+
+#pragma unroll
+        for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
+            Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded);
+        }
+    }
+
+    __syncthreads();
+
+    // Iterate over ne11 == previous tokens:
+    for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
+        const int k_VKQ_0 = kb0*KQ_stride;
+        mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)];
+
+        // Load K data into tile with decreasing granularity for D for better memory bandwidth:
+        static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
+#pragma unroll
+        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+            const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
+            const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
+            const int stride_i = WARP_SIZE / stride_k;
+
+#pragma unroll
+            for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) {
+                const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+#pragma unroll
+                for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) {
+                    const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ];
+                }
+            }
+        }
+
+        __syncthreads();
+
+        // Calculate tile of KQ:
+#pragma unroll
+        for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) {
+            const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I;
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) {
+                mma_A K_A;
+                K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
+                KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]);
+            }
+        }
+
+        __syncthreads();
+
+        if (use_logit_softcap) {
+            static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+            for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) {
+#pragma unroll
+                for (int l = 0; l < mma_C_KQ::ne; ++l) {
+                    KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
+                }
+            }
+        }
+
+        if (maskh) {
+            static_assert(KQ_stride % (np       *mma_C_KQ::I) == 0, "bad loop size");
+            static_assert(ncols     % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size");
+#pragma unroll
+            for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) {
+                const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I;
+#pragma unroll
+                for (int l = 0; l < mma_C_KQ::ne; ++l) {
+                    const int i = i0 + mma_C_KQ::get_i(l);
+                    const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l);
+
+                    KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
+                }
+            }
+        }
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+        float2 KQ_max_new = KQ_max;
+        static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
+#pragma unroll
+            for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) {
+                KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
+                KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
+            }
+        }
+
+        // Values per KQ column are spread across 8 threads, does not need full warp reduce:
+#pragma unroll
+        for (int offset = 16; offset > 2; offset >>= 1) {
+            KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
+            KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
+        }
+
+        {
+            const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
+            KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
+            if (diff.x <= SOFTMAX_FTZ_THRESHOLD) {
+                KQ_max_scale.x = 0.0f;
+            }
+            if (diff.y <= SOFTMAX_FTZ_THRESHOLD) {
+                KQ_max_scale.y = 0.0f;
+            }
+            KQ_max = KQ_max_new;
+        }
+
+        float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
+        static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
+#pragma unroll
+            for (int l = 0; l < mma_C_KQ::ne; ++l) {
+                const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y;
+                const float diff = KQ_C[k].x[l] - KQ_max_l;
+                KQ_C[k].x[l] = expf(diff);
+                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                    KQ_C[k].x[l] = 0.0f;
+                }
+
+                if (l % 2 == 0) {
+                    KQ_rowsum_add.x += KQ_C[k].x[l];
+                } else {
+                    KQ_rowsum_add.y += KQ_C[k].x[l];
+                }
+            }
+        }
+
+        // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+        KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
+        KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;
+
+        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
+#pragma unroll
+        for (int i = 0; i < D/mma_C_VKQ::I; ++i) {
+#pragma unroll
+            for (int l = 0; l < mma_C_VKQ::ne; ++l) {
+                VKQ_C[i].x[l] *= KQ_max_scale_h2;
+            }
+        }
+
+        // Convert KQ C tiles into B tiles for VKQ calculation:
+        mma_B B[KQ_stride/(np*2*mma_B::K)];
+        static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) {
+            B[k] = KQ_C[k].to_mma_B();
+        }
+
+        // Load V data into tile with decreasing granularity for D for better memory bandwidth:
+        static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
+#pragma unroll
+        for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+            const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i);
+            const int i0_stop  =                             D/2 - (D/2) % (1*stride_i);
+            const int stride_k = WARP_SIZE / stride_i;
+
+#pragma unroll
+            for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) {
+                const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i);
+
+#pragma unroll
+                for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) {
+                    const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i);
+
+                    tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V];
+                }
+            }
+        }
+
+        __syncthreads();
+
+        // Calculate VKQ tile:
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) {
+            static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size");
+#pragma unroll
+            for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) {
+                const int k0 = k00 + (threadIdx.y % np)*mma_A::K;
+
+                mma_A A;
+                A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
+                VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]);
+            }
+        }
+
+        __syncthreads();
+    }
+
+    // Finally, sum up partial KQ rowsums.
+    // The partial sums are spread across 8 threads each, does not need full reduce.
+#pragma unroll
+    for (int offset = 16; offset > 2; offset >>= 1) {
+        KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE);
+        KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE);
+    }
+
+    // Write VKQ accumulators to shared memory in column-major format.
+    // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
+    // Also for np > 1 the combination is done via these values in shared memory.
+    const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data
+#pragma unroll
+    for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
+        const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format.
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int k = k0 + mma_B::get_k(l);
+
+            tile_KV[j_cwd*D2_padded + k] = B.x[l];
+        }
+    }
+
+    const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset
+    const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
+    const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum
+
+    if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) {
+        // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
+        ((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
+    }
+
+    __syncthreads();
+
+    static_assert(np == 1 || np == 2 || np == 4, "bad np");
+    if (np == 1) {
+        // No combination is needed, the meta data can be directly written from registers to VRAM.
+        if (needs_fixup && threadIdx.x < mma_B::J) {
+            float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+            dstk_fixup_meta[j_cwm] = KQ_cmr;
+        }
+        if (is_fixup && threadIdx.x < mma_B::J) {
+            float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+            dstk_fixup_meta[j_cwm] = KQ_cmr;
+        }
+    } else if (threadIdx.y % np == 0) {
+        // Combine the meta data for parallel warps via shared memory.
+        // Warps with threadIdx.y % np != 0 must NOT return early.
+        // All threads must return simultaneously to avoid race conditions with work on the next tile.
+
+        float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2;
+
+        float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
+        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
+            KQ_cm = meta_j[0];
+        }
+
+        float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
+#pragma unroll
+        for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
+            KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
+        }
+
+        const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
+        float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
+        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
+            KQ_crs = KQ_cms*meta_j[1];
+        }
+#pragma unroll
+        for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
+            KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
+        }
+
+        // Write back combined meta data:
+        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
+            meta_j[0] = KQ_cmn; // Combined max. KQ values.
+            meta_j[1] = KQ_crs; // Combined KQ rowsums.
+            meta_j[2] = KQ_cms; // KQ max scales per parallel warp.
+        }
+        if (needs_fixup && threadIdx.x < mma_B::J) {
+            float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+            dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+        }
+        if (is_fixup && threadIdx.x < mma_B::J) {
+            float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+            dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+        }
+    }
+
+    if (np > 1) {
+        __syncthreads();
+    }
+
+    if (np == 1 || threadIdx.y % np == 0) {
+        // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
+        // The values after that are for the partial results of the individual blocks.
+        float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2));
+
+#pragma unroll
+        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+            const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
+            const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
+            const int stride_j = WARP_SIZE / stride_k;
+
+            if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
+                break;
+            }
+
+#pragma unroll
+            for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
+                const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+                const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J;
+
+                if (!is_fixup && jt*ncols + j_dst >= ne01) {
+                    continue;
+                }
+                const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2;
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    float2 dstk_val = make_float2(0.0f, 0.0f);
+#pragma unroll
+                    for (int ip = 0; ip < np; ++ip) {
+                        const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2];
+                        const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]);
+                        dstk_val.x += dstk_val_add.x*KQ_crs;
+                        dstk_val.y += dstk_val_add.y*KQ_crs;
+                    }
+
+                    if (!needs_fixup && !is_fixup) {
+                        const float KQ_rowsum_j = meta_j[1];
+                        dstk_val.x /= KQ_rowsum_j;
+                        dstk_val.y /= KQ_rowsum_j;
+                    }
+
+                    if (is_fixup) {
+                        dstk_fixup_data[j_dst*(D/2) + k] = dstk_val;
+                    } else {
+                        dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val;
+                    }
+                }
+            }
+        }
+    }
+
+    if (np > 1) {
+        __syncthreads();
+    }
+#else
+   NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+}
+
+template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap>
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 2)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride");
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+
+    const int iter_k = ne11 / KQ_stride;
+    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+
+    // kbc == k block continuous, current index in continuous ijk space.
+    int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x;
+
+    // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
+    // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
+    // In the most general case >2 seams can fall into the same tile.
+
+    // kb0 == k start index when in the output tile.
+    int kb0_start = kbc % iter_k;
+    int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
+    while (kbc < kbc_stop && kb0_stop == iter_k) {
+        const int channel = kbc / (iter_k*iter_j);
+        const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+
+        const float2 * Q_f2  = (const float2 *) (Q + nb02* channel);
+        const half2  * K_h2  = (const half2  *) (K + nb12*(channel / gqa_ratio));
+        const half2  * V_h2  = (const half2  *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
+        const half   * maskh = mask ? (const half  *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
+        float2       * dstk  = ((float2 *) dst) + channel*(D/2);
+
+        const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
+
+        constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
+        if (kb0_start == 0) {
+            constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
+            flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
+                (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
+                ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
+                jt, kb0_start, kb0_stop);
+        } else {
+            constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
+            flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
+                (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
+                ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
+                jt, kb0_start, kb0_stop);
+        }
+
+        kbc += iter_k;
+        kbc -= kbc % iter_k;
+
+        kb0_start = 0;
+        kb0_stop  = min(iter_k, kbc_stop - kbc);
+    }
+
+    if (kbc >= kbc_stop) {
+        return;
+    }
+
+    const int channel = kbc / (iter_k*iter_j);
+    const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+
+    const float2 * Q_f2  = (const float2 *) (Q + nb02* channel);
+    const half2  * K_h2  = (const half2  *) (K + nb12*(channel / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
+    const half   * maskh = mask ? (const half  *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
+    float2       * dstk  = ((float2 *) dst) + channel*(D/2);
+
+    const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
+
+    constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
+    constexpr bool needs_fixup = false;
+    flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
+        (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
+        ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
+        jt, kb0_start, kb0_stop);
+}
+
+template <int D, int cols_per_block>
+void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    typedef mma_A_I16K8<half2> mma_A;
+    typedef mma_B_J8K8<half2>  mma_B;
+
+    static_assert(D              % mma_B::K == 0, "bad D");
+    static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block");
+
+    const ggml_tensor * KQV = dst;
+
+    constexpr int    KQ_stride     = D <= 128 ? 64 : 32;
+    constexpr int    nwarps        = (KQ_stride == 32 && cols_per_block <= 16) ?
+                                     cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8);
+    constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half);
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    fattn_kernel_t fattn_kernel;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
+    } else {
+        constexpr bool use_logit_softcap = true;
+        fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
+    }
+    launch_fattn<D, cols_per_block, 0, KQ_stride>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+}
+
+#define DECL_FATTN_MMA_F16_CASE(D, cols_per_block)                          \
+    template void ggml_cuda_flash_attn_ext_mma_f16_case                     \
+    <D, cols_per_block>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_MMA_F16_CASE( 64,  8);
+extern DECL_FATTN_MMA_F16_CASE( 80,  8);
+extern DECL_FATTN_MMA_F16_CASE( 96,  8);
+extern DECL_FATTN_MMA_F16_CASE(112,  8);
+extern DECL_FATTN_MMA_F16_CASE(128,  8);
+extern DECL_FATTN_MMA_F16_CASE(256,  8);
+
+extern DECL_FATTN_MMA_F16_CASE( 64, 16);
+extern DECL_FATTN_MMA_F16_CASE( 80, 16);
+extern DECL_FATTN_MMA_F16_CASE( 96, 16);
+extern DECL_FATTN_MMA_F16_CASE(112, 16);
+extern DECL_FATTN_MMA_F16_CASE(128, 16);
+extern DECL_FATTN_MMA_F16_CASE(256, 16);
+
+extern DECL_FATTN_MMA_F16_CASE( 64, 32);
+extern DECL_FATTN_MMA_F16_CASE( 80, 32);
+extern DECL_FATTN_MMA_F16_CASE( 96, 32);
+extern DECL_FATTN_MMA_F16_CASE(112, 32);
+extern DECL_FATTN_MMA_F16_CASE(128, 32);
+extern DECL_FATTN_MMA_F16_CASE(256, 32);
+
+extern DECL_FATTN_MMA_F16_CASE( 64, 64);
+extern DECL_FATTN_MMA_F16_CASE( 80, 64);
+extern DECL_FATTN_MMA_F16_CASE( 96, 64);
+extern DECL_FATTN_MMA_F16_CASE(112, 64);
+extern DECL_FATTN_MMA_F16_CASE(128, 64);
+extern DECL_FATTN_MMA_F16_CASE(256, 64);
index 4d314dacb1958c8a169b61d5b3ac3fa9e58012eb..d4edbad07f26a2842299e3bcb8a5df480ef74dc9 100644 (file)
@@ -45,7 +45,17 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_AVAILABLE
+
+#ifndef FLASH_ATTN_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FLASH_ATTN_AVAILABLE
+
     // Skip unused kernel variants for faster compilation:
+#ifdef FP16_MMA_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FP16_MMA_AVAILABLE
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
         return;
@@ -288,16 +298,18 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
-            constexpr int      D = 64;
-            constexpr int nwarps = 8;
+            constexpr int    D             = 64;
+            constexpr int    nwarps        = 8;
+            constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+            launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         case 128: {
-            constexpr int      D = 128;
-            constexpr int nwarps = 8;
+            constexpr int    D             = 128;
+            constexpr int    nwarps        = 8;
+            constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+            launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
index bb33604470fdd461b7763c980067bf099b145eef..0d274f33255b7ee03f4f7aaa82f49557fe4c565c 100644 (file)
@@ -48,7 +48,12 @@ static __global__ void flash_attn_tile_ext_f32(
     NO_DEVICE_CODE;
     return;
 #endif // FLASH_ATTN_AVAILABLE
+
     // Skip unused kernel variants for faster compilation:
+#ifdef FP16_MMA_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FP16_MMA_AVAILABLE
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
         return;
@@ -287,16 +292,18 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
-            constexpr int      D = 64;
-            constexpr int nwarps = 8;
+            constexpr int    D             = 64;
+            constexpr int    nwarps        = 8;
+            constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+            launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         case 128: {
-            constexpr int      D = 128;
-            constexpr int nwarps = 8;
+            constexpr int    D             = 128;
+            constexpr int    nwarps        = 8;
+            constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+            launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
index 34a2992c769b9d4d338664bd338db0526c4b5c00..d9ac4424606e40cfad933ebbb8082ad5900c8352 100644 (file)
@@ -42,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_AVAILABLE
+
+#ifndef FLASH_ATTN_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FLASH_ATTN_AVAILABLE
+
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -303,7 +309,8 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+    constexpr size_t nbytes_shared = 0;
+    launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
index a28fc8b7fc893ea5ed8a4fb9a746e84fcdf5a2b3..6ef8f9dcc27564b019dee207c25e4013012c7df2 100644 (file)
@@ -41,6 +41,11 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+#ifndef FLASH_ATTN_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FLASH_ATTN_AVAILABLE
+
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -284,7 +289,8 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+    constexpr size_t nbytes_shared = 0;
+    launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
new file mode 100644 (file)
index 0000000..1054ff9
--- /dev/null
@@ -0,0 +1,648 @@
+// Old and deprecated WMMA FlashAttention implementation.
+// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing.
+// Long-term the WMMA code should be replaced with a dedicated Volta implementation.
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-wmma-f16.cuh"
+
+#ifdef FP16_MMA_AVAILABLE
+#include <mma.h>
+#endif // FP16_MMA_AVAILABLE
+
+// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
+    const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
+
+    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
+    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
+    constexpr int frag_m = ncols == 8 ? 32 : 16;
+    constexpr int frag_n = ncols == 8 ?  8 : 16;
+    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
+
+    constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
+    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
+    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
+
+    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
+    constexpr int D_padded = D + 8;
+    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
+    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
+    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
+    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
+
+    const int stride_Q  = nb01 / sizeof(float);
+    const int stride_KV = nb11 / sizeof(half);
+
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+    const half2 slope2 = make_half2(slopef, slopef);
+
+    const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
+
+    frag_b Q_b[D/16][ncols/frag_n];
+
+    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
+    constexpr int mem_KQ = ncols*kqs_padded*kqar;
+    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
+    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
+    float * KQ_f = (float *) KQ;
+    half2 * KQ2 = (half2 *) KQ;
+
+    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
+    float       KQ_max_f[ncols/nwarps];
+    float KQ_max_scale_f[ncols/nwarps] = {0.0f};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_f[j] = -FLT_MAX/2.0f;
+    }
+
+    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+    half2       KQ_max_h2[ncols/nwarps];
+    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
+    }
+
+    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
+    half2 * VKQ2 = (half2 *) VKQ;
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                break;
+            }
+            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
+        }
+    }
+
+    // Convert Q to half and apply scale, temporarily store in KQ:
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+        }
+    }
+
+    __syncthreads();
+
+    // Load Q into tensor core fragments/registers since it will be used frequently:
+#pragma unroll
+    for (int i0 = 0; i0 < D; i0 += 16) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+        }
+    }
+
+    __syncthreads();
+
+    // Iterate over ne11 == previous tokens:
+    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
+        // Calculate tile of KQ:
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
+            frag_c_KQ KQ_c[ncols/frag_n];
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            }
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+                frag_a_K K_a;
+                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                }
+            }
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (std::is_same<KQ_acc_t, float>::value) {
+                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+
+                    if (use_logit_softcap) {
+                        KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+                    }
+                }
+
+                float KQ_max_new = KQ_max_f[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = warp_reduce_max(KQ_max_new);
+
+                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_f[j0/nwarps] = expf(diff);
+                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                    KQ_max_scale_f[j0/nwarps] = 0.0f;
+                }
+                KQ_max_f[j0/nwarps] = KQ_max_new;
+
+                float KQ_rowsum_add = 0.0f;
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
+                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+                    }
+                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
+                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
+            } else {
+                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+
+                    if (use_logit_softcap) {
+                        // There is no dedicated tangens hyperbolicus function for half2.
+                        KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
+                        KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
+                                               /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+
+                        KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
+                    }
+                }
+
+                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
+                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
+                KQ_max_h2[j0/nwarps] = KQ_max_new;
+
+                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
+                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
+                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
+                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
+            }
+        }
+
+        __syncthreads();
+
+        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+                nvcuda::wmma::load_matrix_sync(
+                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
+                    KQ + j0*(kqar*kqs_padded) + k,
+                    kqar*kqs_padded);
+            }
+        }
+
+        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+            }
+
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+
+                frag_a_V v_a;
+                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                }
+            }
+        }
+
+        __syncthreads();
+
+        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync(
+                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
+                    D_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            half2 VKQ_scale;
+            if (std::is_same<KQ_acc_t, float>::value) {
+                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
+            } else {
+                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                    break;
+                }
+
+                half2 VKQ_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int l = 0; l < VKQ_ratio; ++l) {
+                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
+                }
+                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j_VKQ = j0 + threadIdx.y;
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+
+        float KQ_rowsum_j;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
+        } else {
+            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            float dst_val = VKQ[j_VKQ*D_padded + i];
+            if (parallel_blocks == 1) {
+                dst_val /= KQ_rowsum_j;
+            }
+            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
+        }
+
+        if (parallel_blocks == 1 || threadIdx.x != 0) {
+            continue;
+        }
+
+        float2 dst_meta_val;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            dst_meta_val.x = KQ_max_f[j0/nwarps];
+        } else {
+            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
+        }
+        dst_meta_val.y = KQ_rowsum_j;
+        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
+    }
+#else
+   NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+}
+
+constexpr int get_max_power_of_2(int x) {
+    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
+}
+
+static_assert(get_max_power_of_2(1) == 1, "Test failed.");
+static_assert(get_max_power_of_2(2) == 2, "Test failed.");
+static_assert(get_max_power_of_2(4) == 4, "Test failed.");
+static_assert(get_max_power_of_2(6) == 2, "Test failed.");
+
+// Number of VKQ rows calculated in parallel:
+constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
+    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
+}
+
+static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
+static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
+
+template <int D, int cols_per_block, typename KQ_acc_t>
+void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    constexpr int nwarps = 4;
+
+    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
+    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (4*blocks_num_pb1 < 2*nsm) {
+        constexpr int parallel_blocks = 4;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
+        launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        return;
+    }
+    if (2*blocks_num_pb1 < 2*nsm) {
+        constexpr int parallel_blocks = 2;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
+        launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        return;
+    }
+    constexpr int parallel_blocks = 1;
+    fattn_kernel_t fattn_kernel;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    } else {
+        constexpr bool use_logit_softcap = true;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    }
+    launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+}
+
+void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
+
+    if (prec != GGML_PREC_DEFAULT) {
+        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
+            constexpr int cols_per_block = 16;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                case 256:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+                    break;
+            }
+        } else {
+            constexpr int cols_per_block = 32;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                // case 256:
+                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                //     break;
+                default:
+                    GGML_ABORT("fatal error");
+                    break;
+            }
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+        constexpr int cols_per_block = 8;
+        switch (Q->ne[0]) {
+            case 64:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+                break;
+            case 96:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+                break;
+            case 128:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+                break;
+            case 256:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+                break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 32) {
+        constexpr int cols_per_block = 16;
+        switch (Q->ne[0]) {
+            case 64:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+                break;
+            case 80:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+                break;
+            case 96:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+                break;
+            case 112:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+                break;
+            case 128:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+                break;
+            case 256:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+                break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+        return;
+    }
+
+    constexpr int cols_per_block = 32;
+    switch (Q->ne[0]) {
+        case 64:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+            break;
+        case 80:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+            break;
+        case 96:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+            break;
+        case 112:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+            break;
+        case 128:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+            break;
+        case 256:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
index 860d0e6dc2fa467d9da9e076027649eee22c1d71..beeea95eb1d629be62560df18a3a35e435434b65 100644 (file)
@@ -1,543 +1,3 @@
 #include "common.cuh"
-#include "fattn-common.cuh"
 
-#ifdef FP16_MMA_AVAILABLE
-#include <mma.h>
-#endif // FP16_MMA_AVAILABLE
-
-// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-static __global__ void flash_attn_ext_f16(
-        const char * __restrict__ Q,
-        const char * __restrict__ K,
-        const char * __restrict__ V,
-        const char * __restrict__ mask,
-        float      * __restrict__ dst,
-        float2     * __restrict__ dst_meta,
-        const float scale,
-        const float max_bias,
-        const float m0,
-        const float m1,
-        const uint32_t n_head_log2,
-        const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int nb31,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
-#ifdef FP16_MMA_AVAILABLE
-    // Skip unused kernel variants for faster compilation:
-    if (use_logit_softcap && !(D == 128 || D == 256)) {
-        NO_DEVICE_CODE;
-        return;
-    }
-
-    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
-
-    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
-    const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
-
-    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
-    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
-    constexpr int frag_m = ncols == 8 ? 32 : 16;
-    constexpr int frag_n = ncols == 8 ?  8 : 16;
-    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
-
-    constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
-    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
-    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
-
-    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
-    constexpr int D_padded = D + 8;
-    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
-    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
-
-    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
-    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
-
-    const int stride_Q  = nb01 / sizeof(float);
-    const int stride_KV = nb11 / sizeof(half);
-
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
-    const half  slopeh = __float2half(slopef);
-    const half2 slope2 = make_half2(slopef, slopef);
-
-    const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
-
-    frag_b Q_b[D/16][ncols/frag_n];
-
-    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
-    constexpr int mem_KQ = ncols*kqs_padded*kqar;
-    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
-    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
-    float * KQ_f = (float *) KQ;
-    half2 * KQ2 = (half2 *) KQ;
-
-    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
-    float       KQ_max_f[ncols/nwarps];
-    float KQ_max_scale_f[ncols/nwarps] = {0.0f};
-
-#pragma unroll
-    for (int j = 0; j < ncols/nwarps; ++j) {
-        KQ_max_f[j] = -FLT_MAX/2.0f;
-    }
-
-    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
-    half2       KQ_max_h2[ncols/nwarps];
-    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
-
-#pragma unroll
-    for (int j = 0; j < ncols/nwarps; ++j) {
-        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
-    }
-
-    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
-    half2 * VKQ2 = (half2 *) VKQ;
-#pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
-#pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
-                break;
-            }
-            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
-        }
-    }
-
-    // Convert Q to half and apply scale, temporarily store in KQ:
-#pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
-#pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
-                break;
-            }
-            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
-        }
-    }
-
-    __syncthreads();
-
-    // Load Q into tensor core fragments/registers since it will be used frequently:
-#pragma unroll
-    for (int i0 = 0; i0 < D; i0 += 16) {
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
-        }
-    }
-
-    __syncthreads();
-
-    // Iterate over ne11 == previous tokens:
-    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
-        // Calculate tile of KQ:
-#pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
-            frag_c_KQ KQ_c[ncols/frag_n];
-#pragma unroll
-            for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
-            }
-#pragma unroll
-            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
-                frag_a_K K_a;
-                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
-#pragma unroll
-                for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
-                }
-            }
-#pragma unroll
-            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
-            }
-        }
-
-        __syncthreads();
-
-        // Calculate softmax for each KQ column using the current max. value.
-        // The divisor is stored in KQ_rowsum and will be applied at the end.
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
-
-            if (std::is_same<KQ_acc_t, float>::value) {
-                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
-
-                    if (use_logit_softcap) {
-                        KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
-                    }
-                }
-
-                float KQ_max_new = KQ_max_f[j0/nwarps];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
-                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
-                }
-                KQ_max_new = warp_reduce_max(KQ_max_new);
-
-                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
-                KQ_max_scale_f[j0/nwarps] = expf(diff);
-                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
-                    KQ_max_scale_f[j0/nwarps] = 0.0f;
-                }
-                KQ_max_f[j0/nwarps] = KQ_max_new;
-
-                float KQ_rowsum_add = 0.0f;
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
-                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
-                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
-                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
-                    }
-                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
-                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
-                }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
-
-                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
-                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
-            } else {
-                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
-
-                    if (use_logit_softcap) {
-                        // There is no dedicated tangens hyperbolicus function for half2.
-                        KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
-                        KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
-                                               /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
-
-                        KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
-                    }
-                }
-
-                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
-                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
-                }
-                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
-                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
-                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
-                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
-                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
-                KQ_max_h2[j0/nwarps] = KQ_max_new;
-
-                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
-                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
-                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
-                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
-                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
-                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
-                }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
-
-                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
-                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
-            }
-        }
-
-        __syncthreads();
-
-        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-#pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
-                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-                nvcuda::wmma::load_matrix_sync(
-                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
-                    KQ + j0*(kqar*kqs_padded) + k,
-                    kqar*kqs_padded);
-            }
-        }
-
-        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
-#pragma unroll
-        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
-#pragma unroll
-            for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
-            }
-
-#pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
-                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-
-                frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
-#pragma unroll
-                for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
-                }
-            }
-        }
-
-        __syncthreads();
-
-        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
-#pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
-#pragma unroll
-            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(
-                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
-                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
-            }
-        }
-
-        __syncthreads();
-
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
-
-            half2 VKQ_scale;
-            if (std::is_same<KQ_acc_t, float>::value) {
-                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
-            } else {
-                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
-            }
-
-#pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-                const int i = i0 + threadIdx.x;
-                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
-                    break;
-                }
-
-                half2 VKQ_add = make_half2(0.0f, 0.0f);
-#pragma unroll
-                for (int l = 0; l < VKQ_ratio; ++l) {
-                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
-                }
-                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
-            }
-        }
-
-        __syncthreads();
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j_VKQ = j0 + threadIdx.y;
-        if (ic0 + j_VKQ >= ne01) {
-            return;
-        }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-
-        float KQ_rowsum_j;
-        if (std::is_same<KQ_acc_t, float>::value) {
-            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
-        } else {
-            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
-        }
-
-#pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
-                break;
-            }
-            float dst_val = VKQ[j_VKQ*D_padded + i];
-            if (parallel_blocks == 1) {
-                dst_val /= KQ_rowsum_j;
-            }
-            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
-        }
-
-        if (parallel_blocks == 1 || threadIdx.x != 0) {
-            continue;
-        }
-
-        float2 dst_meta_val;
-        if (std::is_same<KQ_acc_t, float>::value) {
-            dst_meta_val.x = KQ_max_f[j0/nwarps];
-        } else {
-            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
-        }
-        dst_meta_val.y = KQ_rowsum_j;
-        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
-    }
-#else
-   NO_DEVICE_CODE;
-#endif // FP16_MMA_AVAILABLE
-}
-
-constexpr int get_max_power_of_2(int x) {
-    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
-}
-
-static_assert(get_max_power_of_2(1) == 1, "Test failed.");
-static_assert(get_max_power_of_2(2) == 2, "Test failed.");
-static_assert(get_max_power_of_2(4) == 4, "Test failed.");
-static_assert(get_max_power_of_2(6) == 2, "Test failed.");
-
-// Number of VKQ rows calculated in parallel:
-constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
-    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
-}
-
-static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
-static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
-static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
-static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
-static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
-static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
-static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
-static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
-static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
-
-template <int D, int cols_per_block, typename KQ_acc_t>
-void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q   = dst->src[0];
-
-    constexpr int nwarps = 4;
-
-    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
-    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
-
-    float logit_softcap;
-    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
-
-    if (4*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 4;
-        fattn_kernel_t fattn_kernel;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        } else {
-            constexpr bool use_logit_softcap = true;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        }
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
-        return;
-    }
-    if (2*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 2;
-        fattn_kernel_t fattn_kernel;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        } else {
-            constexpr bool use_logit_softcap = true;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        }
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
-        return;
-    }
-    constexpr int parallel_blocks = 1;
-    fattn_kernel_t fattn_kernel;
-    if (logit_softcap == 0.0f) {
-        constexpr bool use_logit_softcap = false;
-        fattn_kernel = flash_attn_ext_f16<
-            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-    } else {
-        constexpr bool use_logit_softcap = true;
-        fattn_kernel = flash_attn_ext_f16<
-            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-    }
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
-}
-
-#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t)                         \
-    template void ggml_cuda_flash_attn_ext_wmma_f16_case                              \
-    <D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE(112, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE(128, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE(112, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE(128, 32, float);
-// extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64,  8, half);
-extern DECL_FATTN_WMMA_F16_CASE( 96,  8, half);
-extern DECL_FATTN_WMMA_F16_CASE(128,  8, half);
-extern DECL_FATTN_WMMA_F16_CASE(256,  8, half);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE(112, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE(128, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
+void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 0b26b0f8e0595c5bc2a7c5faa1e1667bd048664d..b1e66d470832c79c2827dcc5dcf52a064c21e9a3 100644 (file)
@@ -1,5 +1,6 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
+#include "fattn-mma-f16.cuh"
 #include "fattn-tile-f16.cuh"
 #include "fattn-tile-f32.cuh"
 #include "fattn-vec-f16.cuh"
 #include "fattn-wmma-f16.cuh"
 #include "fattn.cuh"
 
-#include <cstdint>
+template <int cols_per_block>
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
 
-static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q   = dst->src[0];
-
-    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
-
-    if (prec != GGML_PREC_DEFAULT) {
-        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
-            constexpr int cols_per_block = 16;
-            switch (Q->ne[0]) {
-                case 64:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
-                    break;
-                case 80:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
-                    break;
-                case 96:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
-                    break;
-                case 112:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
-                    break;
-                case 128:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
-                    break;
-                case 256:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
-                    break;
-                default:
-                    GGML_ABORT("fatal error");
-                    break;
-            }
-        } else {
-            constexpr int cols_per_block = 32;
-            switch (Q->ne[0]) {
-                case 64:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
-                    break;
-                case 80:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
-                    break;
-                case 96:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
-                    break;
-                case 112:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
-                    break;
-                case 128:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
-                    break;
-                // case 256:
-                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
-                //     break;
-                default:
-                    GGML_ABORT("fatal error");
-                    break;
-            }
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
-        constexpr int cols_per_block = 8;
-        switch (Q->ne[0]) {
-            case 64:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
-                break;
-            case 96:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
-                break;
-            case 128:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
-                break;
-            case 256:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 32) {
-        constexpr int cols_per_block = 16;
-        switch (Q->ne[0]) {
-            case 64:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
-                break;
-            case 80:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
-                break;
-            case 96:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
-                break;
-            case 112:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
-                break;
-            case 128:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
-                break;
-            case 256:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
-        return;
-    }
-
-    constexpr int cols_per_block = 32;
     switch (Q->ne[0]) {
         case 64:
-            ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst);
             break;
         case 80:
-            ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst);
             break;
         case 96:
-            ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst);
             break;
         case 112:
-            ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst);
             break;
         case 128:
-            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst);
             break;
         case 256:
-            ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst);
             break;
         default:
             GGML_ABORT("fatal error");
             break;
     }
 }
+
+static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+
+    if (Q->ne[1] <= 8) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] <= 16) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] <= 32) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
+        return;
+    }
+
+    ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
+}
+
 #define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
     if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
         ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
@@ -322,11 +235,19 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
 
-    if (!fp16_mma_available(cc)) {
-        if (Q->ne[1] <= 8) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+    if (!new_mma_available(cc)) {
+        if (prec == GGML_PREC_DEFAULT) {
+            if (Q->ne[1] <= 8) {
+                ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+            } else {
+                ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+            }
         } else {
-            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+            if (Q->ne[1] <= 8) {
+                ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+            } else {
+                ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+            }
         }
         return;
     }
@@ -341,5 +262,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         }
     }
 
-    ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+    // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
+    if (cc == GGML_CUDA_CC_VOLTA) {
+        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+    }
+
+    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
 }
index 7d11540afd2371730a6d88bc880ce905d57c5b23..9788a1389a35df58862034994f4d2b3824862f0e 100644 (file)
@@ -1,11 +1,67 @@
+// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
+// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
+// The documentation for the PTX instructions can be found under:
+//   https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
+//
+// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
+// A is a row-major matrix with shape I x K.
+// B is a column-major matrix with shape K x J.
+// C is a column-major matrix with shape I x J.
+// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements.
+// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile.
+// All matrix tiles have ne physical 32 bit elements per warp.
+//
+// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+
 #include "common.cuh"
 
-struct mma_int_A_I16K4 {
+
+#if CUDART_VERSION >= 11800
+
+static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
+    int ret = 0;
+
+#ifdef NEW_MMA_AVAILABLE
+    asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
+        : "+r"(ret) : "r"(x));
+#else
+    NO_DEVICE_CODE;
+#endif // defined(NEW_MMA_AVAILABLE)
+    return ret;
+}
+
+#else
+
+static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
+    // Imagine transposing row-major matrix to column-major matrix.
+    const int src_i_low  = 2 * (threadIdx.x % 4);
+    const int src_i_high = src_i_low + 1;
+    const int src_j      = threadIdx.x / 4;
+
+    const int src_laneid_low  = src_i_low  * 4 + src_j / 2;
+    const int src_laneid_high = src_i_high * 4 + src_j / 2;
+
+    const int shift_low  = ((src_j + 0) % 2) * 16;
+    const int shift_high = ((src_j + 1) % 2) * 16;
+
+    const int ret_low  = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low,  WARP_SIZE) >> shift_low)  & 0x0000FFFF;
+    const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
+
+    return ret_low | ret_high;
+}
+
+#endif // CUDART_VERSION >= 11800
+
+
+template <typename T>
+struct mma_A_I16K4 {
+    static_assert(sizeof(T) == 4, "bad type size");
+
     static constexpr int I  = 16;
     static constexpr int K  = 4;
     static constexpr int ne = 2;
 
-    int x[ne] = {0};
+    T x[ne];
 
     static __device__ __forceinline__ int get_i(const int l) {
         const int ret = (l%2) * (I/2) + threadIdx.x / K;
@@ -21,27 +77,35 @@ struct mma_int_A_I16K4 {
         return ret;
     }
 
-    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
-#if defined(INT8_MMA_AVAILABLE)
-        const int * xs = xs0 + (threadIdx.x%I)*stride;
-        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
-            : "+r"(x[0]), "+r"(x[1])
-            : "l"(xs));
-#else
+    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
 #pragma unroll
         for (int l = 0; l < ne; ++l) {
             x[l] = xs0[get_i(l)*stride + get_k(l)];
         }
-#endif // defined(INT8_MMA_AVAILABLE)
+    }
+
+    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int *) x;
+        const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(xi[0]), "+r"(xi[1])
+            : "l"(xs));
+#else
+        load_generic(xs0, stride);
+#endif // NEW_MMA_AVAILABLE
     }
 };
 
-struct mma_int_A_I16K8 {
+template <typename T>
+struct mma_A_I16K8 {
+    static_assert(sizeof(T) == 4, "bad type size");
+
     static constexpr int I  = 16;
     static constexpr int K  = 8;
     static constexpr int ne = 4;
 
-    int x[ne] = {0};
+    T x[ne];
 
     static __device__ __forceinline__ int get_i(const int l) {
         const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
@@ -57,31 +121,62 @@ struct mma_int_A_I16K8 {
         return ret;
     }
 
-    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
-#if defined(INT8_MMA_AVAILABLE)
-        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
-        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
-            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
-            : "l"(xs));
-#else
+    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
 #pragma unroll
         for (int l = 0; l < ne; ++l) {
             x[l] = xs0[get_i(l)*stride + get_k(l)];
         }
-#endif // defined(INT8_MMA_AVAILABLE)
     }
 
-    __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
-        ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
+    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int * ) x;
+        const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+            : "l"(xs));
+#else
+        GGML_UNUSED(xs0);
+        GGML_UNUSED(stride);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
+    __device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int * ) x;
+        const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
+            : "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3])
+            : "l"(xs));
+#else
+        GGML_UNUSED(xs0);
+        GGML_UNUSED(stride);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
+    __device__ __forceinline__ void transpose() {
+        int * xi  = (int *) x;
+        xi[0] = ggml_cuda_movmatrix(xi[0]);
+
+        const int tmp = ggml_cuda_movmatrix(xi[1]);
+        xi[1] = ggml_cuda_movmatrix(xi[2]);
+        xi[2] = tmp;
+
+        xi[3] = ggml_cuda_movmatrix(xi[3]);
     }
 };
 
-struct mma_int_B_J8K4 {
+template <typename T>
+struct mma_B_J8K4 {
+    static_assert(sizeof(T) == 4, "bad type size");
+
     static constexpr int J  = 8;
     static constexpr int K  = 4;
     static constexpr int ne = 1;
 
-    int x[ne] = {0};
+    T x[ne];
 
     static __device__ __forceinline__ int get_j(const int /* l */) {
         const int ret = threadIdx.x / K;
@@ -97,27 +192,34 @@ struct mma_int_B_J8K4 {
         return ret;
     }
 
-    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
-#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
-        const int * xs = xs0 + (threadIdx.x%J)*stride;
-        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
-            : "+r"(x[0])
-            : "l"(xs));
-#else
+    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
 #pragma unroll
         for (int l = 0; l < ne; ++l) {
             x[l] = xs0[get_j(l)*stride + get_k(l)];
         }
-#endif // defined(INT8_MMA_AVAILABLE)
+    }
+
+    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int *) x;
+        const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
+            : "+r"(xi[0]) : "l"(xs));
+#else
+        load_generic(xs0, stride);
+#endif // NEW_MMA_AVAILABLE
     }
 };
 
-struct mma_int_B_J8K8 {
+template <typename T>
+struct mma_B_J8K8 {
+    static_assert(sizeof(T) == 4, "bad type size");
+
     static constexpr int J  = 8;
     static constexpr int K  = 8;
     static constexpr int ne = 2;
 
-    int x[ne] = {0};
+    T x[ne];
 
     static __device__ __forceinline__ int get_j(const int /* l */) {
         const int ret = threadIdx.x / (K/2);
@@ -133,22 +235,31 @@ struct mma_int_B_J8K8 {
         return ret;
     }
 
-    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
-#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
-        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
-        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
-            : "+r"(x[0]), "+r"(x[1])
-            : "l"(xs));
-#else
+    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
 #pragma unroll
         for (int l = 0; l < ne; ++l) {
             x[l] = xs0[get_j(l)*stride + get_k(l)];
         }
-#endif // defined(INT8_MMA_AVAILABLE)
+    }
+
+    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int *) x;
+        const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(xi[0]), "+r"(xi[1])
+            : "l"(xs));
+#else
+        load_generic(xs0, stride);
+#endif // NEW_MMA_AVAILABLE
     }
 };
 
-struct mma_int_C_I16J8 {
+template <typename T>
+struct mma_C_I16J8 {};
+
+template <>
+struct mma_C_I16J8<int> {
     static constexpr int I  = 16;
     static constexpr int J  = 8;
     static constexpr int ne = 4;
@@ -169,8 +280,8 @@ struct mma_int_C_I16J8 {
         return ret;
     }
 
-    __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
-#ifdef INT8_MMA_AVAILABLE
+    __device__ __forceinline__ void mma(const mma_A_I16K4<int> & mma_A, const mma_B_J8K4<int> & mma_B) {
+#ifdef NEW_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
             : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
@@ -188,11 +299,11 @@ struct mma_int_C_I16J8 {
         GGML_UNUSED(mma_A);
         GGML_UNUSED(mma_B);
         NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
-    __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
-#ifdef INT8_MMA_AVAILABLE
+    __device__ __forceinline__ void mma(const mma_A_I16K8<int> & mma_A, const mma_B_J8K8<int> & mma_B) {
+#ifdef NEW_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
             : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
@@ -216,6 +327,132 @@ struct mma_int_C_I16J8 {
         GGML_UNUSED(mma_A);
         GGML_UNUSED(mma_B);
         NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
+    }
+};
+
+template <>
+struct mma_C_I16J8<half2> {
+    static constexpr int I  = 16;
+    static constexpr int J  = 4;
+    static constexpr int ne = 2;
+
+    half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = l * (I/2) + threadIdx.x / J;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x % J;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
+#ifdef NEW_MMA_AVAILABLE
+        int * Axi = (int *) mma_A.x;
+        int * Bxi = (int *) mma_B.x;
+        int * xi  = (int *) x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(xi[0]), "+r"(xi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(xi[0]), "+r"(xi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(xi[0]), "+r"(xi[1])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
+    __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
+        mma_B_J8K8<half2> mma_B;
+
+        int * xi   = (int *) x;
+        int * Bxi  = (int *) mma_B.x;
+        Bxi[0] = ggml_cuda_movmatrix(xi[0]);
+        Bxi[1] = ggml_cuda_movmatrix(xi[1]);
+
+        return mma_B;
+    }
+};
+
+template <>
+struct mma_C_I16J8<float> {
+    static constexpr int I  = 16;
+    static constexpr int J  = 8;
+    static constexpr int ne = 4;
+
+    float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_j(const int l) {
+        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
+#ifdef NEW_MMA_AVAILABLE
+        int * Axi = (int *) mma_A.x;
+        int * Bxi = (int *) mma_B.x;
+        int * xi  = (int *) x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
+    __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
+        mma_B_J8K8<half2> mma_B;
+        mma_B.x[0] = make_half2(x[0], x[1]);
+        mma_B.x[1] = make_half2(x[2], x[3]);
+
+        int * Bxi  = (int *) mma_B.x;
+        Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
+        Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
+
+        return mma_B;
+    }
+
+    __device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) {
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_i(l)];
+        }
     }
 };
index 270251df4f115d7aaee6dc82556a7e2aeaead23b..83cb78cbdd45111f59f321b870b39755b512502e 100644 (file)
@@ -132,7 +132,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }
 
-    if (int8_mma_available(cc)) {
+    if (new_mma_available(cc)) {
         return true;
     }
 
index 3cd508a1d4a1c93db27368afbbe8dbcc0c27dbf7..c05c8477812b15031beed0fce0024d632cb2f8a7 100644 (file)
@@ -87,7 +87,7 @@ struct tile_x_sizes {
 };
 
 static constexpr int get_mmq_x_max_host(const int cc) {
-    return int8_mma_available(cc) ? 128 :
+    return new_mma_available(cc) ? 128 :
 #ifdef GGML_CUDA_FORCE_MMQ
         cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128                     : 64;
 #else
@@ -96,9 +96,9 @@ static constexpr int get_mmq_x_max_host(const int cc) {
 }
 
 static constexpr __device__ int get_mmq_x_max_device() {
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     return 128;
-#else // INT8_MMA_AVAILABLE
+#else // NEW_MMA_AVAILABLE
 
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     return 128;
@@ -116,7 +116,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 static constexpr int get_mmq_y_host(const int cc) {
@@ -209,10 +209,10 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
-    return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
+    return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
 }
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
     return mmq_x >= 48 ? 16 : 8;
 }
@@ -220,21 +220,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
 static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
     return 8;
 }
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
 // ------------------------------------------------------------
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI4_0;
     const int kqsx = threadIdx.x % QI4_0;
@@ -250,12 +250,12 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b2(bxi->qs, kqsx);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
 #else
         x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
@@ -271,11 +271,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0       + kbxd] = bxi->d;
 #else
         x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -322,14 +322,14 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI4_1;
     const int kqsx = threadIdx.x % QI4_1;
@@ -345,12 +345,12 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b4(bxi->qs, kqsx);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0]     = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
         x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
@@ -366,11 +366,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1       + kbxd] = bxi->dm;
 #else
         x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -417,14 +417,14 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI5_0;
     const int kqsx = threadIdx.x % QI5_0;
@@ -456,13 +456,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
         qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
@@ -478,25 +478,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0       + kbxd] = bxi->d;
 #else
         x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI5_1;
     const int kqsx = threadIdx.x % QI5_1;
@@ -526,13 +526,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
         qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
@@ -548,25 +548,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1       + kbxd] = bxi->dm;
 #else
         x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_tile + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI8_0;
     const int kqsx = threadIdx.x % QI8_0;
@@ -581,13 +581,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0         + threadIdx.x] = get_int_b2(bxi[0].qs,               kqsx);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + 0         + threadIdx.x] = get_int_b2(bxi[0].qs,               kqsx);
         x_qs[i*(2*WARP_SIZE + 1)     + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
@@ -603,11 +603,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0             + kbxd] = bxi->d;
 #else
         x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -645,9 +645,9 @@ template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
-    typedef mma_int_A_I16K8 mma_A;
-    typedef mma_int_B_J8K8  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
+    typedef mma_A_I16K8<int> mma_A;
+    typedef mma_B_J8K8<int>  mma_B;
+    typedef mma_C_I16J8<int> mma_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
@@ -672,7 +672,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
             const int k0 = k00 + k01;
 
-            A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+            A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
         }
 
 #pragma unroll
@@ -695,7 +695,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
             mma_B  B;
             float dB[mma_C::ne/2];
 
-            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+            B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
 
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -711,7 +711,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
                 mma_C C;
-                C.mma_K8(A[n][k01/QI8_0], B);
+                C.mma(A[n][k01/QI8_0], B);
 
 #pragma unroll
                 for (int l = 0; l < mma_C::ne; ++l) {
@@ -756,9 +756,9 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
-    typedef mma_int_A_I16K8 mma_A;
-    typedef mma_int_B_J8K8  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
+    typedef mma_A_I16K8<int> mma_A;
+    typedef mma_B_J8K8<int>  mma_B;
+    typedef mma_C_I16J8<int> mma_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
@@ -782,7 +782,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
             const int k0 = k00 + k01;
 
-            A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+            A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
         }
 
 #pragma unroll
@@ -805,7 +805,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
             mma_B    B;
             float2 dsB[mma_C::ne/2];
 
-            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+            B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
 
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -817,7 +817,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
                 mma_C C;
-                C.mma_K8(A[n][k01/QI8_1], B);
+                C.mma(A[n][k01/QI8_1], B);
 
 #pragma unroll
                 for (int l = 0; l < mma_C::ne; ++l) {
@@ -864,12 +864,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 
-    typedef mma_int_A_I16K4 mma_A;
-    typedef mma_int_A_I16K8 mma_A_K8;
-    typedef mma_int_B_J8K4  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
+    typedef mma_A_I16K4<int> mma_A;
+    typedef mma_A_I16K8<int> mma_A_K8;
+    typedef mma_B_J8K4<int>  mma_B;
+    typedef mma_C_I16J8<int> mma_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
@@ -893,7 +893,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             const int k0 = k00 + k01;
 
-            ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+            ((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
         }
 
 #pragma unroll
@@ -916,8 +916,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
             mma_B B[2];
             float dB[mma_C::ne/2];
 
-            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
-            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+            // Here load_generic is faster than load_ldmatrix.
+            B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
+            B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
 
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -929,8 +930,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
                 mma_C C[2];
-                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
-                C[1].mma_K4(A[n][k01/4 + 1], B[1]);
+                C[0].mma(A[n][k01/4 + 0], B[0]);
+                C[1].mma(A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
                 for (int l = 0; l < mma_C::ne; ++l) {
@@ -942,20 +943,20 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 #else
     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % QI2_K;
 
@@ -977,11 +978,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
             const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + k] = x_qs_k;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int sc_m = bxi->scales[kqsx];
@@ -992,11 +993,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
 #endif // FAST_FP16_AVAILABLE
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
 #else
         x_dm[i*(WARP_SIZE + 1)       + kqsx] = x_dm_ik;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -1051,12 +1052,12 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 
-    typedef mma_int_A_I16K4 mma_A;
-    typedef mma_int_A_I16K8 mma_A_K8;
-    typedef mma_int_B_J8K4  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
+    typedef mma_A_I16K4<int> mma_A;
+    typedef mma_A_I16K8<int> mma_A_K8;
+    typedef mma_B_J8K4<int>  mma_B;
+    typedef mma_C_I16J8<int> mma_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
@@ -1081,7 +1082,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
             const int k0 = k00 + k01;
 
-            ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+            ((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
         }
     }
 
@@ -1118,24 +1119,25 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
             mma_B B[2];
 
-            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
-            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+            // Here load_generic is faster than load_ldmatrix.
+            B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
+            B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
 
             mma_C Cm[2];
             if (k01 >= WARP_SIZE * 3/4) {
                 mma_A A1;
                 A1.x[0] = 0x01010101;
                 A1.x[1] = 0x01010101;
-                Cm[0].mma_K4(A1, B[0]);
-                Cm[1].mma_K4(A1, B[1]);
+                Cm[0].mma(A1, B[0]);
+                Cm[1].mma(A1, B[1]);
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
                 mma_C Cd[2];
 
-                Cd[0].mma_K4(A[n][k01/4 + 0], B[0]);
-                Cd[1].mma_K4(A[n][k01/4 + 1], B[1]);
+                Cd[0].mma(A[n][k01/4 + 0], B[0]);
+                Cd[1].mma(A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
                 for (int l = 0; l < mma_C::ne; ++l) {
@@ -1172,13 +1174,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #else
     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
@@ -1186,7 +1188,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_df + txs.dm);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % QI3_K;
 
@@ -1212,11 +1214,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
             const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + k] = x_qs_k;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
     }
 
@@ -1242,7 +1244,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         const int8_t * sc8 = (const int8_t *) &sc;
         const float d = bxi->d;
 
@@ -1252,10 +1254,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         }
 #else
         x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
-#ifndef INT8_MMA_AVAILABLE
+#ifndef NEW_MMA_AVAILABLE
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
         int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
@@ -1268,7 +1270,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         x_df[i] = bxi->d;
     }
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
@@ -1317,7 +1319,7 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
@@ -1325,7 +1327,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_dm + txs.dm);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -1338,15 +1340,15 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
         const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
         x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
@@ -1407,7 +1409,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
     }
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
@@ -1446,7 +1448,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
 #else
@@ -1454,7 +1456,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_dm + txs.dm);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -1478,16 +1480,16 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
         const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + kq0] = ql0 | qh0;
         x_qs[i*(2*WARP_SIZE + 1)     + kq1] = ql1 | qh1;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
@@ -1548,7 +1550,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
     }
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template <int mmq_x, int mmq_y, int nwarps>
@@ -1587,7 +1589,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
     int   * x_sc = (int   *) (x_df + WARP_SIZE/QI6_K);
@@ -1596,7 +1598,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_df + txs.dm);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -1619,13 +1621,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
         const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
         x_qs[i*(2*WARP_SIZE + 1)     + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI6_K;  // == 1 if QK_K == 256
@@ -1641,11 +1643,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q6_K       + kbxd] = bxi->d;
 #else
         x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
 #pragma unroll
@@ -1658,11 +1660,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
 #else
         x_sc[i*(WARP_SIZE/8) + i/8   + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -1702,11 +1704,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 
-    typedef mma_int_A_I16K4 mma_A;
-    typedef mma_int_B_J8K4  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
+    typedef mma_A_I16K4<int> mma_A;
+    typedef mma_B_J8K4<int>  mma_B;
+    typedef mma_C_I16J8<int> mma_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
@@ -1732,8 +1734,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             const int k0 = k00 + k01;
 
-            A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0),        MMQ_MMA_TILE_X_K_Q6_K);
-            A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+            A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0),        MMQ_MMA_TILE_X_K_Q6_K);
+            A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
         }
 
 #pragma unroll
@@ -1771,8 +1773,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
             mma_B B[2];
             float dB[mma_C::ne/2];
 
-            B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0        + k01, MMQ_TILE_Y_K);
-            B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
+            // Here load_generic is faster than load_ldmatrix.
+            B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0        + k01, MMQ_TILE_Y_K);
+            B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
 
 #pragma unroll
             for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -1784,8 +1787,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
                 mma_C C[2];
-                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
-                C[1].mma_K4(A[n][k01/4 + 1], B[1]);
+                C[0].mma(A[n][k01/4 + 0], B[0]);
+                C[1].mma(A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
                 for (int l = 0; l < mma_C::ne; ++l) {
@@ -1805,20 +1808,20 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #else
     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI4_NL;
     const int kqsx = threadIdx.x % QI4_NL;
@@ -1836,13 +1839,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int aux_q4 = get_int_b2(bxi->qs, kqsx);
         const int2 v = get_int_from_table_16(aux_q4);
         const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
@@ -1858,25 +1861,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + kbxd] = __half2float(bxi->d);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % (QI2_XXS/2);
 
@@ -1905,36 +1908,36 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
             const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid0;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid1;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = (ls*d + d/2)/4;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % (QI2_XS/2);
 
@@ -1959,38 +1962,38 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int ls = bxi->scales[kqsx];
         const float d = bxi->d;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
 #else
         x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % (QI2_S/2);
 
@@ -2022,38 +2025,38 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int ls = bxi->scales[kqsx];
         const float d = bxi->d;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
 #else
         x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % (QI3_XXS/2);
 
@@ -2080,36 +2083,36 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = (ls*d + d/2)/2;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % (QI3_S/2);
 
@@ -2143,36 +2146,36 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+0)] = grid_l;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+1)] = grid_h;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
         const float d = bxi->d;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = ls*d;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_ds = (half2 *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % QI1_S;
 
@@ -2198,37 +2201,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int grid0 = (grid >> 0) & 0x0F0F0F0F;
             const int grid1 = (grid >> 4) & 0x0F0F0F0F;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+0)] = grid0;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+1)] = grid1;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const float  d1q   = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
         const float  delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
 #else
         x_ds[i*(WARP_SIZE/4) + i/4   + kqsx] = make_half2(d1q, d1q*delta);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = 0;           // threadIdx.x / QI4_XS
     const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
@@ -2246,13 +2249,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int aux_q4 = get_int_b4(bxi->qs, kqsx);
         const int2 v = get_int_from_table_16(aux_q4);
         const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
 #pragma unroll
@@ -2270,11 +2273,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
             | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + threadIdx.x % 8] = d * (ls - 32);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -2307,16 +2310,16 @@ template<int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_mma(
     const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
 
-    typedef mma_int_C_I16J8 mma_C;
+    typedef mma_C_I16J8<int> mma_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
 
     const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
@@ -2505,13 +2508,13 @@ static __device__ void mul_mat_q_process_tile(
     int * tile_y = (int *) data_mul_mat_q;
     int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
     constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
 #else
     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     constexpr int blocks_per_iter = MMQ_ITER_K / qk;
 
@@ -2643,7 +2646,7 @@ static __global__ void mul_mat_q(
     const int jt =  kbc /    (blocks_per_ne00*nty);
     const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
 
-    constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks.
+    constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
         (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
             it, jt, kb0_start, kb0_stop);
@@ -2749,7 +2752,7 @@ template<ggml_type type>
 static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
     const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
     const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
-    const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+    const int shmem_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
     const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
     return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
 }
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu
new file mode 100644 (file)
index 0000000..f09bdef
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16);
+DECL_FATTN_MMA_F16_CASE(80, 16);
+DECL_FATTN_MMA_F16_CASE(96, 16);
+DECL_FATTN_MMA_F16_CASE(112, 16);
+DECL_FATTN_MMA_F16_CASE(128, 16);
+DECL_FATTN_MMA_F16_CASE(256, 16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu
new file mode 100644 (file)
index 0000000..2211088
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 32);
+DECL_FATTN_MMA_F16_CASE(80, 32);
+DECL_FATTN_MMA_F16_CASE(96, 32);
+DECL_FATTN_MMA_F16_CASE(112, 32);
+DECL_FATTN_MMA_F16_CASE(128, 32);
+DECL_FATTN_MMA_F16_CASE(256, 32);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu
new file mode 100644 (file)
index 0000000..d24b085
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64);
+DECL_FATTN_MMA_F16_CASE(80, 64);
+DECL_FATTN_MMA_F16_CASE(96, 64);
+DECL_FATTN_MMA_F16_CASE(112, 64);
+DECL_FATTN_MMA_F16_CASE(128, 64);
+DECL_FATTN_MMA_F16_CASE(256, 64);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu
new file mode 100644 (file)
index 0000000..bdf86c0
--- /dev/null
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8);
+DECL_FATTN_MMA_F16_CASE(80, 8);
+DECL_FATTN_MMA_F16_CASE(96, 8);
+DECL_FATTN_MMA_F16_CASE(112, 8);
+DECL_FATTN_MMA_F16_CASE(128, 8);
+DECL_FATTN_MMA_F16_CASE(256, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu
deleted file mode 100644 (file)
index 2d94e65..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-wmma-f16.cuh"
-
-DECL_FATTN_WMMA_F16_CASE(64, 16, float);
-DECL_FATTN_WMMA_F16_CASE(80, 16, float);
-DECL_FATTN_WMMA_F16_CASE(96, 16, float);
-DECL_FATTN_WMMA_F16_CASE(112, 16, float);
-DECL_FATTN_WMMA_F16_CASE(128, 16, float);
-DECL_FATTN_WMMA_F16_CASE(256, 16, float);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu
deleted file mode 100644 (file)
index c3d9df3..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-wmma-f16.cuh"
-
-DECL_FATTN_WMMA_F16_CASE(64, 32, float);
-DECL_FATTN_WMMA_F16_CASE(80, 32, float);
-DECL_FATTN_WMMA_F16_CASE(96, 32, float);
-DECL_FATTN_WMMA_F16_CASE(112, 32, float);
-DECL_FATTN_WMMA_F16_CASE(128, 32, float);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu
deleted file mode 100644 (file)
index bb680e4..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-wmma-f16.cuh"
-
-DECL_FATTN_WMMA_F16_CASE(64, 16, half);
-DECL_FATTN_WMMA_F16_CASE(80, 16, half);
-DECL_FATTN_WMMA_F16_CASE(96, 16, half);
-DECL_FATTN_WMMA_F16_CASE(112, 16, half);
-DECL_FATTN_WMMA_F16_CASE(128, 16, half);
-DECL_FATTN_WMMA_F16_CASE(256, 16, half);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu
deleted file mode 100644 (file)
index 073f71b..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-wmma-f16.cuh"
-
-DECL_FATTN_WMMA_F16_CASE(64, 32, half);
-DECL_FATTN_WMMA_F16_CASE(80, 32, half);
-DECL_FATTN_WMMA_F16_CASE(96, 32, half);
-DECL_FATTN_WMMA_F16_CASE(112, 32, half);
-DECL_FATTN_WMMA_F16_CASE(128, 32, half);
-DECL_FATTN_WMMA_F16_CASE(256, 32, half);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu
deleted file mode 100644 (file)
index d30710c..0000000
+++ /dev/null
@@ -1,8 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-wmma-f16.cuh"
-
-DECL_FATTN_WMMA_F16_CASE(64, 8, half);
-DECL_FATTN_WMMA_F16_CASE(96, 8, half);
-DECL_FATTN_WMMA_F16_CASE(128, 8, half);
-DECL_FATTN_WMMA_F16_CASE(256, 8, half);
index d7874e6eaf83280d4e48438b4fa9f639723749f8..a2628f16e57d1fe85a755649c42b32e3804acd69 100755 (executable)
@@ -12,13 +12,13 @@ SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.p
 DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v});
 """
 
-SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
-#include "../fattn-wmma-f16.cuh"
+#include "../fattn-mma-f16.cuh"
 
 """
 
-SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n"
+SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n"
 
 TYPES_MMQ = [
     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
@@ -57,20 +57,12 @@ for vkq_size in [16, 32]:
                 with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
                     f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
 
-for kq_acc_t in ["half", "float"]:
-    for cols_per_block in [8, 16, 32]:
-        if kq_acc_t == "float" and cols_per_block == 8:
-            continue
+for cols_per_block in [8, 16, 32, 64]:
+    with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f:
+        f.write(SOURCE_FATTN_MMA_START)
 
-        with open(f"fattn-wmma-f16-instance-kq{kq_acc_t}-cpb{cols_per_block}.cu", "w") as f:
-            f.write(SOURCE_FATTN_WMMA_START)
-
-            for head_size in [64, 80, 96, 112, 128, 256]:
-                if cols_per_block == 8 and head_size % 32 != 0: # wmma fragment is 8x32
-                    continue
-                if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance
-                    continue
-                f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size))
+        for head_size in [64, 80, 96, 112, 128, 256]:
+            f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size))
 
 for type in TYPES_MMQ:
     with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
index 8594093f052ef1355321b95f94d1c724cd6e93bd..129478ed785e7c0d0395d7a593203e30bca335c4 100644 (file)
@@ -25,6 +25,7 @@
 #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
 #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
 #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
+#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
 #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
 #define cublasCreate hipblasCreate
index 7a877bdc11a6f529c15044316ea6fbbd0d51fb51..eb03e10fa48a1b39ef444571a22c078f3213583e 100644 (file)
@@ -50,7 +50,7 @@ file(GLOB   GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
 list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
 
 file(GLOB   GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
-file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
+file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
 file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
index 415b2b2e09c1b3514a30bd2d1cd90d372bf3a2c8..2f555416e62cf3ab7df050e03438ea5e452641b9 100644 (file)
@@ -29,7 +29,7 @@ if (MUSAToolkit_FOUND)
     list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
 
     file(GLOB   GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
-    file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
+    file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
     file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})