]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16 (#15131)
authorJohannes Gäßler <redacted>
Thu, 7 Aug 2025 08:53:21 +0000 (10:53 +0200)
committerGitHub <redacted>
Thu, 7 Aug 2025 08:53:21 +0000 (10:53 +0200)
* CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16

15 files changed:
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/fattn-mma-f16.cuh
ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mma.cuh
ggml/src/ggml-cuda/mmf.cu [new file with mode: 0644]
ggml/src/ggml-cuda/mmf.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmq.cuh
ggml/src/ggml-cuda/mmv.cu [deleted file]
ggml/src/ggml-cuda/mmv.cuh [deleted file]
ggml/src/ggml-cuda/mmvf.cu [new file with mode: 0644]
ggml/src/ggml-cuda/mmvf.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/vendors/hip.h
ggml/src/ggml-cuda/vendors/musa.h

index 8f27255476d5981acc07d397cfb1df7bb82eae95..2e5d48797fa49350b9fc731a37746304c56cc093 100644 (file)
@@ -233,9 +233,13 @@ typedef float2 dfloat2;
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-#define NEW_MMA_AVAILABLE
+#define TURING_MMA_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define AMPERE_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #define CP_ASYNC_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -303,10 +307,14 @@ static bool amd_mfma_available(const int cc) {
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
-static bool new_mma_available(const int cc) {
+static bool turing_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
 
+static bool ampere_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
 static bool cp_async_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
 }
index e7570f9d3b830340ed113a4c52429d304324127a..3712538441719b0158bbe74cb17c420a4e1d45c2 100644 (file)
@@ -418,7 +418,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         float        * const __restrict__ KQ_max,
         float        * const __restrict__ KQ_rowsum,
         const int kb0) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     typedef fattn_mma_f16_config<DKQ, DV> c;
 
 #ifdef CP_ASYNC_AVAILABLE
@@ -776,7 +776,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
     GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
     NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
@@ -800,7 +800,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jt,
         const int kb0_start,
         const int kb0_stop) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     typedef fattn_mma_f16_config<DKQ, DV> c;
@@ -1196,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
     GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
     NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
@@ -1223,7 +1223,7 @@ static __global__ void flash_attn_ext_f16(
                             const int32_t nb21, const int32_t nb22, const int64_t nb23,
                             const int32_t ne31, const int32_t ne32, const int32_t ne33,
                             const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1354,7 +1354,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
 }
 
 template <int DKQ, int DV, int ncols1, int ncols2>
index 656e04a47352e09e2796aacdd345dbf460564844..8ddd0415b7f8f2a4f6a7cbbc20c71952d21d5f51 100644 (file)
@@ -327,7 +327,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
     const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
-    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
+    const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
         (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
@@ -340,7 +340,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
-    if (fp16_mma_available(cc) && !new_mma_available(cc)) {
+    if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
         return;
     }
index 60e481b95af0340128ecf94231fd4f16a60a33d7..ec7ab255188fcdbbbf84b4929b4e2eb58bb10d45 100644 (file)
@@ -22,8 +22,9 @@
 #include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
+#include "ggml-cuda/mmf.cuh"
 #include "ggml-cuda/mmq.cuh"
-#include "ggml-cuda/mmv.cuh"
+#include "ggml-cuda/mmvf.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
@@ -2008,7 +2009,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
         && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
 
-    bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+    bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+    bool use_mul_mat_f     = !ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
@@ -2028,14 +2031,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
             }
 
             const int cc            = ggml_cuda_info().devices[id].cc;
+            const int warp_size     = ggml_cuda_info().devices[id].warp_size;
             use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            use_mul_mat_vec         = use_mul_mat_vec           && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+            use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+            use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
         }
     } else {
         const int cc            = ggml_cuda_info().devices[ctx.device].cc;
+        const int warp_size     = ggml_cuda_info().devices[ctx.device].warp_size;
         use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        use_mul_mat_vec         = use_mul_mat_vec           && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+        use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+        use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
     }
 
@@ -2048,15 +2055,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
     //TODO update for generic tensor parallelism
-    const int cc                     = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int cc                 = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     bool use_batched_cublas_f16  = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
     bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
     bool use_batched_cublas_f32  = src0->type == GGML_TYPE_F32;
 
-    if (!split && use_mul_mat_vec) {
+    if (!split && use_mul_mat_vec_f) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
-        ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
+        ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
+    } else if (!split && use_mul_mat_f) {
+        ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
@@ -2065,8 +2074,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
-    } else if (use_mul_mat_vec) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
+    } else if (use_mul_mat_vec_f) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
     } else if (use_mul_mat_vec_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
@@ -2094,7 +2103,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             if (ggml_is_quantized(src0->type)) {
                 ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
             } else {
-                ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
+                ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
             }
             return;
         }
@@ -3516,7 +3525,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-                if (!new_mma_available(cc)) {
+                if (!turing_mma_available(cc)) {
                     return false;
                 }
                 const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
index a86365c6a061cd7d8f5f5ac9cd10944158730301..83ee16b27d0df448483cb04d2ef8328b3a367e87 100644 (file)
 static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
     int ret = 0;
 
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
         : "=r"(ret) : "r"(x));
 #else
     GGML_UNUSED(x);
     NO_DEVICE_CODE;
-#endif // defined(NEW_MMA_AVAILABLE)
+#endif // defined(TURING_MMA_AVAILABLE)
     return ret;
 }
 
@@ -167,6 +167,38 @@ namespace ggml_cuda_mma {
         }
     };
 
+    template <int I_, int J_>
+    struct tile<I_, J_, nv_bfloat162> {
+        static constexpr int I  = I_;
+        static constexpr int J  = J_;
+        static constexpr int ne = I * J / WARP_SIZE;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return l * 8 + threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l % 2) * 8 + threadIdx.x / 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return l * 4 + threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l / 2) * 4 + threadIdx.x % 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+    };
+
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
         tile<I, J/2, half2> ret;
@@ -209,7 +241,7 @@ namespace ggml_cuda_mma {
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
         asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -217,13 +249,13 @@ namespace ggml_cuda_mma {
             : "l"(xs));
 #else
         load_generic(t, xs0, stride);
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
         asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -232,13 +264,13 @@ namespace ggml_cuda_mma {
 #else
         load_generic(xs0, stride);
         GGML_UNUSED(t);
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#if defined(NEW_MMA_AVAILABLE)
+#if defined(TURING_MMA_AVAILABLE)
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -246,13 +278,13 @@ namespace ggml_cuda_mma {
             : "l"(xs));
 #else
         load_generic(t, xs0, stride);
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix_trans(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
@@ -263,12 +295,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(xs0);
         GGML_UNUSED(stride);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
             : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -287,12 +319,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
             : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -317,12 +349,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int       * Dxi = (int       *) D.x;
@@ -344,12 +376,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int       * Dxi = (int       *) D.x;
@@ -380,12 +412,29 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int       * Dxi = (int       *) D.x;
@@ -407,12 +456,29 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int       * Dxi = (int       *) D.x;
@@ -443,7 +509,7 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu
new file mode 100644 (file)
index 0000000..1437367
--- /dev/null
@@ -0,0 +1,431 @@
+#include "ggml.h"
+#include "common.cuh"
+#include "mma.cuh"
+#include "mmf.cuh"
+
+using namespace ggml_cuda_mma;
+
+#define MMF_ROWS_PER_BLOCK 32
+
+template <typename T, int rows_per_block, int cols_per_block, int nwarps>
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f(
+        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile< 8, 8, T>     tile_B;
+    typedef tile<16, 8, float> tile_C;
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int tile_k_padded = warp_size + 4;
+    constexpr int ntA = rows_per_block / tile_A::I;
+    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
+
+    const int row0        = blockIdx.x * rows_per_block;
+    const int channel_dst = blockIdx.y;
+    const int channel_x   = channel_dst / channel_ratio;
+    const int channel_y   = channel_dst;
+    const int sample_dst  = blockIdx.z;
+    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_y    = sample_dst;
+
+    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row0*stride_row ;
+    y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
+    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
+
+    const float2 * y2 = (const float2 *) y;
+
+    extern __shared__ char data_mmv[];
+
+    tile_C C[ntA][ntB];
+
+    T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
+
+    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
+        tile_A A[ntA][warp_size / tile_A::J];
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int i = 0; i < tile_A::I; ++i) {
+                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row  + col];
+            }
+#pragma unroll
+            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
+                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
+            }
+        }
+
+#pragma unroll
+        for (int itB = 0; itB < ntB; ++itB) {
+            if constexpr (std::is_same_v<T, float>) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + itB*tile_B::I;
+
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
+                }
+            } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + itB*tile_B::I;
+
+                    const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                }
+            } else {
+                static_assert(std::is_same_v<T, void>, "unsupported type");
+            }
+#pragma unroll
+            for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+                tile_B B;
+                load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+                for (int itA = 0; itA < ntA; ++itA) {
+                    mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+                }
+            }
+        }
+    }
+
+    float * buf_iw = (float *) data_mmv;
+    constexpr int kiw = nwarps*rows_per_block + 4;
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+#pragma unroll
+    for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int l = 0; l < tile_C::ne; ++l) {
+                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
+                const int j = itB*tile_C::J + tile_C::get_j(l);
+                buf_iw[j*kiw + i] = C[itA][itB].x[l];
+            }
+        }
+    }
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
+            return;
+        }
+
+        float sum = 0.0f;
+        static_assert(rows_per_block == warp_size, "need loop/check");
+#pragma unroll
+        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
+            const int i = i0 + threadIdx.x;
+
+            sum += buf_iw[j*kiw + i];
+        }
+        dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
+    }
+#else
+    NO_DEVICE_CODE;
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(ids); GGML_UNUSED(dst);
+    GGML_UNUSED(ncols); GGML_UNUSED(nchannels_y); GGML_UNUSED(stride_row); GGML_UNUSED(stride_col_y); GGML_UNUSED(stride_col_dst);
+    GGML_UNUSED(channel_ratio); GGML_UNUSED(stride_channel_x); GGML_UNUSED(stride_channel_y); GGML_UNUSED(stride_channel_dst);
+    GGML_UNUSED(sample_ratio); GGML_UNUSED(stride_sample_x); GGML_UNUSED(stride_sample_y); GGML_UNUSED(stride_sample_dst);
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+}
+
+template <typename T, int cols_per_block>
+static void mul_mat_f_cuda(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nrows_x,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile< 8, 8, T>     tile_B;
+    typedef tile<16, 8, float> tile_C;
+
+    GGML_ASSERT(!ids && "mul_mat_id not implemented");
+
+    GGML_ASSERT(ncols_x      % 2 == 0);
+    GGML_ASSERT(stride_row   % 2 == 0);
+    GGML_ASSERT(stride_col_y % 2 == 0);
+    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
+    const int64_t channel_ratio = nchannels_dst / nchannels_x;
+    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
+
+    const int device = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+
+    int64_t nwarps_best     = 1;
+    int64_t niter_best      = (ncols_x + warp_size*2 - 1) / (warp_size*2);
+    int64_t max_block_size  = 256;
+    for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
+        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
+        if (niter < niter_best) {
+            niter_best  = niter;
+            nwarps_best = nwarps;
+        }
+    }
+
+    constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
+    const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
+    const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
+    const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
+    const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
+    const dim3 block_dims(warp_size, nwarps_best, 1);
+    switch (nwarps_best) {
+        case 1: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 2: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 3: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 4: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 5: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 6: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 7: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 8: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+template <typename T>
+static void mul_mat_f_switch_cols_per_block(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    switch (ncols_dst) {
+        case  1: {
+            mul_mat_f_cuda<T,  1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  2: {
+            mul_mat_f_cuda<T,  2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  3: {
+            mul_mat_f_cuda<T,  3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  4: {
+            mul_mat_f_cuda<T,  4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  5: {
+            mul_mat_f_cuda<T,  5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  6: {
+            mul_mat_f_cuda<T,  6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  7: {
+            mul_mat_f_cuda<T,  7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  8: {
+            mul_mat_f_cuda<T,  8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  9: {
+            mul_mat_f_cuda<T,  9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 10: {
+            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 11: {
+            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 12: {
+            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 13: {
+            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 14: {
+            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 15: {
+            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 16: {
+            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(         dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
+
+    GGML_ASSERT(ne13 == ne3);
+
+    GGML_ASSERT(        nb00       == ts_src0);
+    GGML_ASSERT(        nb10       == ts_src1);
+    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+    GGML_ASSERT(        nb0        == ts_dst);
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+    const float   * src1_d =       (const float   *) src1->data;
+    const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
+    float         *  dst_d =       (float         *)  dst->data;
+
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s11 = src1->nb[1] / ts_src1;
+    const int64_t s1  =  dst->nb[1] / ts_dst;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s12 = src1->nb[2] / ts_src1;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s13 = src1->nb[3] / ts_src1;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
+
+    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
+    const int64_t ncols_dst          = ids ? ne2  : ne1;
+    const int64_t nchannels_y        = ids ? ne11 : ne12;
+    const int64_t nchannels_dst      = ids ? ne1  : ne2;
+    const int64_t stride_channel_dst = ids ? s1   : s2;
+    const int64_t stride_channel_y   = ids ? s11  : s12;
+
+    GGML_ASSERT(!ids || ncols_dst == 1);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0->data;
+            constexpr int vals_per_T = 1;
+            mul_mat_f_switch_cols_per_block(
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
+                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03/vals_per_T, s13,              s3,                 ctx.stream());
+        } break;
+        case GGML_TYPE_F16: {
+            const half2 * src0_d = (const half2 *) src0->data;
+            constexpr int vals_per_T = 2;
+            mul_mat_f_switch_cols_per_block(
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
+                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03/vals_per_T, s13,              s3,                 ctx.stream());
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
+            constexpr int vals_per_T = 2;
+            mul_mat_f_switch_cols_per_block(
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
+                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03/vals_per_T, s13,              s3,                 ctx.stream());
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
+}
+
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) {
+    if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
+        return false;
+    }
+    if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
+        return false;
+    }
+    if (ne11 > 16) {
+        return false;
+    }
+    switch (type) {
+        case GGML_TYPE_F32:
+            return ampere_mma_available(cc);
+        case GGML_TYPE_F16:
+            return turing_mma_available(cc);
+        case GGML_TYPE_BF16:
+            return ampere_mma_available(cc);
+        default:
+            return false;
+    }
+}
diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
new file mode 100644 (file)
index 0000000..785f9f2
--- /dev/null
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11);
index 8954a3831045630692344d7d9745dc8418ae928d..384ee7615f7a4c9a48fa0f1202cdff9348bcf8c5 100644 (file)
@@ -310,7 +310,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }
 
-    if (new_mma_available(cc)) {
+    if (turing_mma_available(cc)) {
         return true;
     }
 
index 1634725c20a5aa630986e32af07aa154db3387bf..96129bd831fd473a0737360415a1cc71029ca6da 100644 (file)
@@ -92,7 +92,7 @@ struct tile_x_sizes {
 };
 
 static int get_mmq_x_max_host(const int cc) {
-    return (amd_mfma_available(cc) || new_mma_available(cc)) ? 128 :
+    return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
         GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
 #ifdef GGML_CUDA_FORCE_MMQ
             128                     : 64;
@@ -102,9 +102,9 @@ static int get_mmq_x_max_host(const int cc) {
 }
 
 static constexpr __device__ int get_mmq_x_max_device() {
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     return 128;
-#else // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
 #if defined(GGML_USE_HIP)
     return 64;
@@ -121,7 +121,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #endif // defined(GGML_USE_HIP)
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 
 static int get_mmq_y_host(const int cc) {
@@ -233,7 +233,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
     if (amd_mfma_available(cc)) {
         return mmq_x >= 128 ? 32 : 16;
-    } else if (new_mma_available(cc) && mmq_x >= 48) {
+    } else if (turing_mma_available(cc) && mmq_x >= 48) {
         return 16;
     } else {
         return 8;
@@ -244,7 +244,7 @@ static int mmq_get_granularity_host(const int mmq_x, const int cc) {
 static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
     return mmq_x >= 128 ? 32 : 16;
 }
-#elif defined(NEW_MMA_AVAILABLE)
+#elif defined(TURING_MMA_AVAILABLE)
 static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
     return mmq_x >= 48 ? 16 : 8;
 }
@@ -279,14 +279,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
     constexpr int nrows = warp_size / threads_per_row;
@@ -305,12 +305,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b2(bxi->qs, kqsx);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
 #else
         x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
@@ -327,11 +327,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0           + kbxd] = bxi->d;
 #else
         x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -382,14 +382,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
     constexpr int nrows = warp_size / threads_per_row;
@@ -408,12 +408,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b4(bxi->qs, kqsx);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0]     = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
         x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
@@ -430,11 +430,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1           + kbxd] = bxi->dm;
 #else
         x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -485,14 +485,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
     constexpr int nrows = warp_size / threads_per_row;
@@ -527,13 +527,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
         qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
@@ -550,11 +550,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0           + kbxd] = bxi->d;
 #else
         x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -563,14 +563,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
     constexpr int nrows = warp_size / threads_per_row;
@@ -603,13 +603,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
         qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
@@ -626,11 +626,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1           + kbxd] = bxi->dm;
 #else
         x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -639,14 +639,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp
     constexpr int threads_per_row = 32;
@@ -665,13 +665,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0             + txi] = get_int_b2(bxi[0].qs,                   kqsx);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0             + txi] = get_int_b2(bxi[0].qs,                   kqsx);
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
     constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
@@ -688,11 +688,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0                 + kbxd] = bxi->d;
 #else
         x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -701,14 +701,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
     constexpr int nrows = warp_size / threads_per_row;
@@ -730,13 +730,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
         const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0]        = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0]        = v.x;
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
@@ -753,11 +753,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_1                 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
 #else
         x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -1178,7 +1178,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
             }
         }
     }
-#elif defined(NEW_MMA_AVAILABLE)
+#elif defined(TURING_MMA_AVAILABLE)
 
     typedef tile<16, 4, int> tile_A;
     typedef tile<16, 8, int> tile_A_8;
@@ -1264,14 +1264,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
     constexpr int nwarps = mmq_get_nwarps_device();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
     constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
@@ -1295,11 +1295,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
             const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
 
         const int sc_m = bxi->scales[kqsx];
@@ -1310,11 +1310,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
 #endif // FAST_FP16_AVAILABLE
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
 #else
         x_dm[i*(MMQ_TILE_NE_K + 1)   + kqsx] = x_dm_ik;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -1452,7 +1452,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
             }
         }
     }
-#elif defined(NEW_MMA_AVAILABLE)
+#elif defined(TURING_MMA_AVAILABLE)
 
     typedef tile<16, 4, int> tile_A;
     typedef tile<16, 8, int> tile_A_8;
@@ -1582,7 +1582,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
@@ -1590,7 +1590,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_df + txs.dm);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
     constexpr int nrows = warp_size / threads_per_row;
@@ -1618,11 +1618,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
             const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
     }
 
@@ -1649,7 +1649,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         const int8_t * sc8 = (const int8_t *) &sc;
         const float d = bxi->d;
 
@@ -1659,10 +1659,10 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         }
 #else
         x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-#if !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE))
+#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
         int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
@@ -1675,7 +1675,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         x_df[i] = bxi->d;
     }
-#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE))
+#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
 }
 
 template <int mmq_x, int mmq_y>
@@ -1728,7 +1728,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
@@ -1736,7 +1736,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_dm + txs.dm);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
     constexpr int nrows = warp_size / threads_per_row;
@@ -1753,15 +1753,15 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
         const int qs0 = get_int_b4(bxi->qs, txi);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
         x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     constexpr int rows_per_warp = warp_size / 2;
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
@@ -1829,7 +1829,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
     }
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 
 template <int mmq_x, int mmq_y>
@@ -1872,7 +1872,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
 #else
@@ -1880,7 +1880,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_dm + txs.dm);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
     constexpr int nrows = warp_size / threads_per_row;
@@ -1908,16 +1908,16 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
         const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     constexpr int rows_per_warp = warp_size / 2;
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
@@ -1986,7 +1986,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
     }
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 
 template <int mmq_x, int mmq_y>
@@ -2029,7 +2029,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
     int   * x_sc = (int   *) (x_df + MMQ_TILE_NE_K/QI6_K);
@@ -2038,7 +2038,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_df + txs.dm);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
     constexpr int nrows = warp_size / threads_per_row;
@@ -2065,13 +2065,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
         const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
 #pragma unroll
@@ -2084,11 +2084,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q6_K]           = bxi->d;
 #else
         x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
     constexpr int rows_per_warp = warp_size / 4;
@@ -2102,11 +2102,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
 #else
         x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -2199,7 +2199,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
             }
         }
     }
-#elif defined(NEW_MMA_AVAILABLE)
+#elif defined(TURING_MMA_AVAILABLE)
 
     typedef tile<16, 4, int> tile_A;
     typedef tile< 8, 4, int> tile_B;
@@ -2311,14 +2311,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
     constexpr int nrows = warp_size / threads_per_row;
@@ -2340,13 +2340,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
         const int k0 = kbx * (2 * QI4_NL) + kqsx;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0]      = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0]      = v.x;
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
@@ -2363,11 +2363,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0             + kbxd] = __half2float(bxi->d);
 #else
         x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -2376,14 +2376,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
     constexpr int nrows = warp_size / threads_per_row;
@@ -2414,22 +2414,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
             const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
             const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + kqsx] = (ls*d + d/2)/4;
 #else
         x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -2438,14 +2438,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
     constexpr int nrows = warp_size / threads_per_row;
@@ -2472,24 +2472,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
             const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
 
         const int ls = bxi->scales[kqsx];
         const float d = bxi->d;
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
 #else
         x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -2498,14 +2498,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
     constexpr int nrows = warp_size / threads_per_row;
@@ -2539,24 +2539,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
             const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
 
         const int ls = bxi->scales[kqsx];
         const float d = bxi->d;
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
 #else
         x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -2565,14 +2565,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
     constexpr int nrows = warp_size / threads_per_row;
@@ -2601,22 +2601,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
             const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0     + kqsx] = (ls*d + d/2)/2;
 #else
         x_df[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = (ls*d + d/2)/2;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -2625,14 +2625,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
     constexpr int nrows = warp_size / threads_per_row;
@@ -2668,22 +2668,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
             const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
 
         const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
         const float d = bxi->d;
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0     + kqsx] = ls*d;
 #else
         x_df[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = ls*d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -2692,14 +2692,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_ds = (half2 *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
     constexpr int nrows = warp_size / threads_per_row;
@@ -2727,23 +2727,23 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
             const int grid0 = (grid >> 0) & 0x0F0F0F0F;
             const int grid1 = (grid >> 4) & 0x0F0F0F0F;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
 #else
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
             x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
 
         const float  d1q   = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
         const float  delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_ds[i*MMQ_MMA_TILE_X_K_Q8_1     + kqsx] = make_half2(d1q, d1q*delta);
 #else
         x_ds[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = make_half2(d1q, d1q*delta);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -2752,14 +2752,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
     constexpr int nrows = warp_size / threads_per_row;
@@ -2779,13 +2779,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
         const int k0 = 8 * (kqsx / 4) + kqsx % 4;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
 #else
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
         x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
     constexpr int rows_per_warp = warp_size / 8;
@@ -2804,11 +2804,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
             | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + threadIdx.x % 8] = d * (ls - 32);
 #else
         x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
@@ -2859,9 +2859,9 @@ static __device__ __forceinline__ void mmq_write_back_mma(
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
     const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
-#if defined(NEW_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
     static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
@@ -3061,13 +3061,13 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
     int * tile_y = data_mul_mat_q + mmq_x;
     int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
     constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
 #else
     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int blocks_per_iter = MMQ_ITER_K / qk;
 
@@ -3534,7 +3534,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
     const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
     const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
     const size_t nbs_ids = mmq_x*sizeof(int);
-    const size_t nbs_x = (new_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+    const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
     const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
     return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
 }
diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu
deleted file mode 100644 (file)
index e14c935..0000000
+++ /dev/null
@@ -1,506 +0,0 @@
-#include "ggml.h"
-#include "common.cuh"
-#include "mmv.cuh"
-
-template <typename T, typename type_acc, int ncols_dst, int block_size>
-static __global__ void mul_mat_vec(
-        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
-        const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
-        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
-    const int row         = blockIdx.x;
-    const int channel_dst = blockIdx.y;
-    const int channel_x   = ids ? ids[channel_dst]          : channel_dst / channel_ratio;
-    const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
-    const int sample_dst  = blockIdx.z;
-    const int sample_x    = sample_dst / sample_ratio;
-    const int sample_y    = sample_dst;
-    const int tid         = threadIdx.x;
-
-    constexpr int warp_size   = ggml_cuda_get_physical_warp_size();
-
-    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
-    y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
-    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
-
-    const float2 * y2 = (const float2 *) y;
-
-    extern __shared__ char data_mmv[];
-    float * buf_iw = (float *) data_mmv;
-
-    if (block_size > warp_size) {
-        if (tid < warp_size) {
-            buf_iw[tid] = 0.0f;
-        }
-        __syncthreads();
-    }
-
-    float sumf[ncols_dst] = {0.0f};
-
-    if constexpr (std::is_same<T, float>::value) {
-        const float2 * x2 = (const float2 *) x;
-
-        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
-            const float2 tmpx = x2[col2];
-
-#pragma unroll
-            for (int j = 0; j < ncols_dst; ++j) {
-                const float2 tmpy = y2[j*stride_col_y2 + col2];
-                sumf[j] += tmpx.x*tmpy.x;
-                sumf[j] += tmpx.y*tmpy.y;
-            }
-        }
-    } else if constexpr (std::is_same<T, half>::value) {
-        const half2 * x2 = (const half2 *) x;
-
-        if (std::is_same<type_acc, float>::value) {
-            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
-                const float2 tmpx = __half22float2(x2[col2]);
-
-#pragma unroll
-                for (int j = 0; j < ncols_dst; ++j) {
-                    const float2 tmpy = y2[j*stride_col_y2 + col2];
-                    sumf[j] += tmpx.x * tmpy.x;
-                    sumf[j] += tmpx.y * tmpy.y;
-                }
-            }
-        } else {
-#ifdef FP16_AVAILABLE
-            half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
-
-            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
-                const half2 tmpx = x2[col2];
-
-#pragma unroll
-                for (int j = 0; j < ncols_dst; ++j) {
-                    const float2 tmpy = y2[j*stride_col_y2 + col2];
-                    sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
-                }
-            }
-
-#pragma unroll
-            for (int j = 0; j < ncols_dst; ++j) {
-                sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
-            }
-#else
-            NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
-        }
-    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
-        const int * x2 = (const int *) x;
-        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
-            const int tmpx = x2[col2];
-#pragma unroll
-            for (int j = 0; j < ncols_dst; ++j) {
-                const float2 tmpy = y2[j*stride_col_y2 + col2];
-                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
-                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
-            }
-        }
-    } else {
-        static_assert(std::is_same<T, void>::value, "unsupported type");
-    }
-
-#pragma unroll
-    for (int j = 0; j < ncols_dst; ++j) {
-        sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
-
-        if (block_size > warp_size) {
-            buf_iw[tid/warp_size] = sumf[j];
-            __syncthreads();
-            if (tid < warp_size) {
-                sumf[j] = buf_iw[tid];
-                sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
-            }
-            if (j < ncols_dst) {
-                __syncthreads();
-            }
-        }
-    }
-
-    if (tid >= ncols_dst) {
-        return;
-    }
-
-    dst[tid*stride_col_dst + row] = sumf[tid];
-}
-
-template <typename T, typename type_acc, int ncols_dst>
-static void launch_mul_mat_vec_cuda(
-        const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols, const int64_t nrows,
-        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
-        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
-        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        cudaStream_t stream) {
-    GGML_ASSERT(ncols        % 2 == 0);
-    GGML_ASSERT(stride_row   % 2 == 0);
-    GGML_ASSERT(stride_col_y % 2 == 0);
-    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
-    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
-    const int64_t channel_ratio = nchannels_dst / nchannels_x;
-    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
-    int device;
-    int warp_size;
-
-    CUDA_CHECK(cudaGetDevice(&device));
-    warp_size = ggml_cuda_info().devices[device].warp_size;
-
-    int64_t block_size_best = warp_size;
-    int64_t niter_best      = (ncols + 2*warp_size - 1) / (2*warp_size);
-    int64_t max_block_size  = 256;
-    if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
-        max_block_size = 128;
-    }
-    for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
-        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
-        if (niter < niter_best) {
-            niter_best      = niter;
-            block_size_best = block_size;
-        }
-    }
-
-    const int smem = warp_size*sizeof(float);
-    const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
-    const dim3 block_dims(block_size_best, 1, 1);
-    switch (block_size_best) {
-        case   32: {
-            mul_mat_vec<T, type_acc, ncols_dst,  32><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case   64: {
-            mul_mat_vec<T, type_acc, ncols_dst,  64><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case   96: {
-            mul_mat_vec<T, type_acc, ncols_dst,  96><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case  128: {
-            mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case  160: {
-            mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case  192: {
-            mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case  224: {
-            mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case  256: {
-            mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        default: {
-            GGML_ABORT("fatal error");
-        } break;
-    }
-}
-
-template <typename T, typename type_acc>
-static void mul_mat_vec_cuda_switch_ncols_dst(
-        const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
-        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
-        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
-        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        cudaStream_t stream) {
-    switch (ncols_dst) {
-        case 1:
-            launch_mul_mat_vec_cuda<T, type_acc, 1>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        case 2:
-            launch_mul_mat_vec_cuda<T, type_acc, 2>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        case 3:
-            launch_mul_mat_vec_cuda<T, type_acc, 3>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        case 4:
-            launch_mul_mat_vec_cuda<T, type_acc, 4>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        case 5:
-            launch_mul_mat_vec_cuda<T, type_acc, 5>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        case 6:
-            launch_mul_mat_vec_cuda<T, type_acc, 6>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        case 7:
-            launch_mul_mat_vec_cuda<T, type_acc, 7>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        case 8:
-            launch_mul_mat_vec_cuda<T, type_acc, 8>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        default:
-            GGML_ABORT("fatal error");
-            break;
-    }
-}
-
-template<typename T>
-static void mul_mat_vec_cuda(
-        const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
-        const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
-        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
-        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        enum ggml_prec prec, cudaStream_t stream) {
-    if constexpr(std::is_same<T, half>::value) {
-        if (prec == GGML_PREC_DEFAULT) {
-            mul_mat_vec_cuda_switch_ncols_dst<T, half>
-                (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            return;
-        }
-    }
-    mul_mat_vec_cuda_switch_ncols_dst<T, float>
-        (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
-         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-         stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-}
-
-void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
-    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
-    GGML_ASSERT(         dst->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const size_t ts_src0 = ggml_type_size(src0->type);
-    const size_t ts_src1 = ggml_type_size(src1->type);
-    const size_t ts_dst  = ggml_type_size(dst->type);
-
-    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for  batch size 1.
-    GGML_ASSERT(ne13 == ne3);
-
-    GGML_ASSERT(        nb00       == ts_src0);
-    GGML_ASSERT(        nb10       == ts_src1);
-    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
-    GGML_ASSERT(        nb0        == ts_dst);
-
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
-
-    const float   * src1_d =       (const float   *) src1->data;
-    const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
-    float         *  dst_d =       (float         *)  dst->data;
-
-    const int64_t s01 = src0->nb[1] / ts_src0;
-    const int64_t s11 = src1->nb[1] / ts_src1;
-    const int64_t s1  =  dst->nb[1] / ts_dst;
-    const int64_t s02 = src0->nb[2] / ts_src0;
-    const int64_t s12 = src1->nb[2] / ts_src1;
-    const int64_t s2  =  dst->nb[2] / ts_dst;
-    const int64_t s03 = src0->nb[3] / ts_src0;
-    const int64_t s13 = src1->nb[3] / ts_src1;
-    const int64_t s3  =  dst->nb[3] / ts_dst;
-
-    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
-    const int64_t ncols_dst          = ids ? ne2  : ne1;
-    const int64_t nchannels_y        = ids ? ne11 : ne12;
-    const int64_t nchannels_dst      = ids ? ne1  : ne2;
-    const int64_t stride_channel_dst = ids ? s1   : s2;
-    const int64_t stride_channel_y   = ids ? s11  : s12;
-
-    GGML_ASSERT(!ids || ncols_dst == 1);
-
-    switch (src0->type) {
-        case GGML_TYPE_F32: {
-            const float * src0_d = (const float *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
-                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
-        } break;
-        case GGML_TYPE_F16: {
-            const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
-                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
-        } break;
-        case GGML_TYPE_BF16: {
-            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
-                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
-        } break;
-        default:
-            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
-    }
-}
-
-void ggml_cuda_op_mul_mat_vec(
-    ggml_backend_cuda_context & ctx,
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
-    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, cudaStream_t stream) {
-
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne0  =  dst->ne[0];
-    const int64_t row_diff = row_high - row_low;
-
-    const int id = ggml_cuda_get_device();
-    const int cc = ggml_cuda_info().devices[id].cc;
-    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
-
-
-    // ggml_cuda_op provides single, contiguous matrices
-    const int64_t stride_row         = ne00;
-    const int64_t stride_col_y       = ne10;
-    const int64_t stride_col_dst     = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
-    const int64_t nchannels_x        = 1;
-    const int64_t nchannels_y        = 1;
-    const int64_t nchannels_dst      = 1;
-    const int64_t stride_channel_x   = 0;
-    const int64_t stride_channel_y   = 0;
-    const int64_t stride_channel_dst = 0;
-    const int64_t nsamples_x         = 1;
-    const int64_t nsamples_dst       = 1;
-    const int64_t stride_sample_x    = 0;
-    const int64_t stride_sample_y    = 0;
-    const int64_t stride_sample_dst  = 0;
-
-    switch (src0->type) {
-        case GGML_TYPE_F32: {
-            const float * src0_d = (const float *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
-        } break;
-        case GGML_TYPE_F16: {
-            const half * src0_d = (const half *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
-        } break;
-        case GGML_TYPE_BF16: {
-            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
-        } break;
-        default:
-            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
-    }
-
-    GGML_UNUSED(ctx);
-    GGML_UNUSED(src1);
-    GGML_UNUSED(dst);
-    GGML_UNUSED(src1_ddq_i);
-    GGML_UNUSED(src1_ncols);
-    GGML_UNUSED(src1_padded_row_size);
-}
-
-bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
-    if (src0_ne[0] % 2 != 0) {
-        return false;
-    }
-    switch (type) {
-        case GGML_TYPE_F32:
-            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
-                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
-                    return ne11 <= 8;
-                }
-                if (cc >= GGML_CUDA_CC_TURING) {
-                    return ne11 <= 4;
-                }
-                return ne11 <= 3;
-            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
-                if (fp32_mma_hardware_available(cc)) {
-                    return ne11 <= 3;
-                }
-                return ne11 <= 8;
-            }
-            return ne11 <= 8;
-        case GGML_TYPE_F16:
-            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
-                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
-                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
-                    return src0_small && ne11 <= 4;
-                }
-                if (fp16_mma_hardware_available(cc)) {
-                    return src0_small && ne11 <= 3;
-                }
-                return ne11 <= 8;
-            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
-                if (fp16_mma_hardware_available(cc)) {
-                    if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
-                        return ne11 <= 5;
-                    }
-                    return ne11 <= 2;
-                }
-                return ne11 <= 8;
-            }
-            return ne11 <= 8;
-        case GGML_TYPE_BF16:
-            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
-                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
-                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
-                    return src0_small && ne11 <= 4;
-                }
-                if (bf16_mma_hardware_available(cc)) {
-                    return src0_small && ne11 <= 3;
-                }
-                return ne11 <= 8;
-            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
-                if (bf16_mma_hardware_available(cc)) {
-                    return ne11 <= 3;
-                }
-                return ne11 <= 8;
-            }
-            return ne11 <= 8;
-        default:
-            return false;
-    }
-}
diff --git a/ggml/src/ggml-cuda/mmv.cuh b/ggml/src/ggml-cuda/mmv.cuh
deleted file mode 100644 (file)
index 1330bcb..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#include "common.cuh"
-
-void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
-
-void ggml_cuda_op_mul_mat_vec(
-    ggml_backend_cuda_context & ctx,
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
-    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, cudaStream_t stream);
-
-bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu
new file mode 100644 (file)
index 0000000..1ad4bc7
--- /dev/null
@@ -0,0 +1,510 @@
+#include "ggml.h"
+#include "common.cuh"
+#include "mmvf.cuh"
+
+template <typename T, typename type_acc, int ncols_dst, int block_size>
+static __global__ void mul_mat_vec_f(
+        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+    const int row         = blockIdx.x;
+    const int channel_dst = blockIdx.y;
+    const int channel_x   = ids ? ids[channel_dst]          : channel_dst / channel_ratio;
+    const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
+    const int sample_dst  = blockIdx.z;
+    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_y    = sample_dst;
+    const int tid         = threadIdx.x;
+
+    constexpr int warp_size   = ggml_cuda_get_physical_warp_size();
+
+    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
+    y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
+    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
+
+    const float2 * y2 = (const float2 *) y;
+
+    extern __shared__ char data_mmv[];
+    float * buf_iw = (float *) data_mmv;
+
+    if (block_size > warp_size) {
+        if (tid < warp_size) {
+            buf_iw[tid] = 0.0f;
+        }
+        __syncthreads();
+    }
+
+    float sumf[ncols_dst] = {0.0f};
+
+    if constexpr (std::is_same_v<T, float>) {
+        const float2 * x2 = (const float2 *) x;
+
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+            const float2 tmpx = x2[col2];
+
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                sumf[j] += tmpx.x*tmpy.x;
+                sumf[j] += tmpx.y*tmpy.y;
+            }
+        }
+    } else if constexpr (std::is_same_v<T, half>) {
+        const half2 * x2 = (const half2 *) x;
+
+        if (std::is_same_v<type_acc, float>) {
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+                const float2 tmpx = __half22float2(x2[col2]);
+
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    const float2 tmpy = y2[j*stride_col_y2 + col2];
+                    sumf[j] += tmpx.x * tmpy.x;
+                    sumf[j] += tmpx.y * tmpy.y;
+                }
+            }
+        } else {
+#ifdef FP16_AVAILABLE
+            half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
+
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+                const half2 tmpx = x2[col2];
+
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    const float2 tmpy = y2[j*stride_col_y2 + col2];
+                    sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
+                }
+            }
+
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
+            }
+#else
+            NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+        }
+    } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
+        const int * x2 = (const int *) x;
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+            const int tmpx = x2[col2];
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+            }
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "unsupported type");
+    }
+
+#pragma unroll
+    for (int j = 0; j < ncols_dst; ++j) {
+        sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
+
+        if (block_size > warp_size) {
+            buf_iw[tid/warp_size] = sumf[j];
+            __syncthreads();
+            if (tid < warp_size) {
+                sumf[j] = buf_iw[tid];
+                sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
+            }
+            if (j < ncols_dst) {
+                __syncthreads();
+            }
+        }
+    }
+
+    if (tid >= ncols_dst) {
+        return;
+    }
+
+    dst[tid*stride_col_dst + row] = sumf[tid];
+}
+
+template <typename T, typename type_acc, int ncols_dst>
+static void launch_mul_mat_vec_f_cuda(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols, const int64_t nrows,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    GGML_ASSERT(ncols        % 2 == 0);
+    GGML_ASSERT(stride_row   % 2 == 0);
+    GGML_ASSERT(stride_col_y % 2 == 0);
+    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
+    const int64_t channel_ratio = nchannels_dst / nchannels_x;
+    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
+
+    const int device = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+
+    int64_t block_size_best = warp_size;
+    int64_t niter_best      = (ncols + 2*warp_size - 1) / (2*warp_size);
+    int64_t max_block_size  = 256;
+    if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
+        max_block_size = 128;
+    }
+    for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
+        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
+        if (niter < niter_best) {
+            niter_best      = niter;
+            block_size_best = block_size;
+        }
+    }
+
+    const int nbytes_shared = warp_size*sizeof(float);
+    const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
+    const dim3 block_dims(block_size_best, 1, 1);
+    switch (block_size_best) {
+        case   32: {
+            mul_mat_vec_f<T, type_acc, ncols_dst,  32><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case   64: {
+            mul_mat_vec_f<T, type_acc, ncols_dst,  64><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case   96: {
+            mul_mat_vec_f<T, type_acc, ncols_dst,  96><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case  128: {
+            mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case  160: {
+            mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case  192: {
+            mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case  224: {
+            mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case  256: {
+            mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+template <typename T, typename type_acc>
+static void mul_mat_vec_f_cuda_switch_ncols_dst(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    switch (ncols_dst) {
+        case 1:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 1>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 2:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 2>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 3:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 3>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 4:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 4>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 5:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 5>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 6:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 6>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 7:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 7>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 8:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 8>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+template<typename T>
+static void mul_mat_vec_f_cuda(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        enum ggml_prec prec, cudaStream_t stream) {
+    if constexpr(std::is_same_v<T, half>) {
+        if (prec == GGML_PREC_DEFAULT) {
+            mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
+                (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            return;
+        }
+    }
+    mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
+        (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+         stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+}
+
+void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(         dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
+
+    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for  batch size 1.
+    GGML_ASSERT(ne13 == ne3);
+
+    GGML_ASSERT(        nb00       == ts_src0);
+    GGML_ASSERT(        nb10       == ts_src1);
+    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+    GGML_ASSERT(        nb0        == ts_dst);
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+    const float   * src1_d =       (const float   *) src1->data;
+    const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
+    float         *  dst_d =       (float         *)  dst->data;
+
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s11 = src1->nb[1] / ts_src1;
+    const int64_t s1  =  dst->nb[1] / ts_dst;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s12 = src1->nb[2] / ts_src1;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s13 = src1->nb[3] / ts_src1;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
+
+    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
+    const int64_t ncols_dst          = ids ? ne2  : ne1;
+    const int64_t nchannels_y        = ids ? ne11 : ne12;
+    const int64_t nchannels_dst      = ids ? ne1  : ne2;
+    const int64_t stride_channel_dst = ids ? s1   : s2;
+    const int64_t stride_channel_y   = ids ? s11  : s12;
+
+    GGML_ASSERT(!ids || ncols_dst == 1);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0->data;
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+        } break;
+        case GGML_TYPE_F16: {
+            const half * src0_d = (const half *) src0->data;
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
+}
+
+void ggml_cuda_op_mul_mat_vec_f(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne0  =  dst->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+
+    // ggml_cuda_op provides single, contiguous matrices
+    const int64_t stride_row         = ne00;
+    const int64_t stride_col_y       = ne10;
+    const int64_t stride_col_dst     = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
+    const int64_t nchannels_x        = 1;
+    const int64_t nchannels_y        = 1;
+    const int64_t nchannels_dst      = 1;
+    const int64_t stride_channel_x   = 0;
+    const int64_t stride_channel_y   = 0;
+    const int64_t stride_channel_dst = 0;
+    const int64_t nsamples_x         = 1;
+    const int64_t nsamples_dst       = 1;
+    const int64_t stride_sample_x    = 0;
+    const int64_t stride_sample_y    = 0;
+    const int64_t stride_sample_dst  = 0;
+
+    switch (src0->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0_dd_i;
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+        } break;
+        case GGML_TYPE_F16: {
+            const half * src0_d = (const half *) src0_dd_i;
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
+
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(src1);
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src1_ddq_i);
+    GGML_UNUSED(src1_ncols);
+    GGML_UNUSED(src1_padded_row_size);
+}
+
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
+    if (src0_ne[0] % 2 != 0) {
+        return false;
+    }
+    switch (type) {
+        case GGML_TYPE_F32:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                if (ampere_mma_available(cc)) {
+                    return ne11 <= 3;
+                }
+                if (cc >= GGML_CUDA_CC_TURING) {
+                    return ne11 <= 4;
+                }
+                return ne11 <= 3;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (fp32_mma_hardware_available(cc)) {
+                    return ne11 <= 3;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        case GGML_TYPE_F16:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (ampere_mma_available(cc)) {
+                    return src0_small && ne11 == 1;
+                }
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return src0_small && ne11 <= 4;
+                }
+                if (fp16_mma_hardware_available(cc)) {
+                    return src0_small && ne11 <= 3;
+                }
+                return ne11 <= 8;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (fp16_mma_hardware_available(cc)) {
+                    if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+                        return ne11 <= 5;
+                    }
+                    return ne11 <= 2;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        case GGML_TYPE_BF16:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (ampere_mma_available(cc)) {
+                    return src0_small && ne11 == 1;
+                }
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return src0_small && ne11 <= 4;
+                }
+                if (bf16_mma_hardware_available(cc)) {
+                    return src0_small && ne11 <= 3;
+                }
+                return ne11 <= 8;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (bf16_mma_hardware_available(cc)) {
+                    return ne11 <= 3;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        default:
+            return false;
+    }
+}
diff --git a/ggml/src/ggml-cuda/mmvf.cuh b/ggml/src/ggml-cuda/mmvf.cuh
new file mode 100644 (file)
index 0000000..1da4609
--- /dev/null
@@ -0,0 +1,11 @@
+#include "common.cuh"
+
+void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
+void ggml_cuda_op_mul_mat_vec_f(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
index 8b172e60f4b7e770d62eb5fb5a7fc8006819a3c3..c31f3192322529b77be55aa1ffa36a201a5ede03 100644 (file)
 #endif
 
 typedef hip_bfloat16 nv_bfloat16;
+typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix
 
 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
 typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
index 198963202443ac2497fa2cd5077f8affd9d98ea4..8c55a2e4e56f1d2ee8fa688073eccaa58160879e 100644 (file)
 #define cudaStreamEndCapture musaStreamEndCapture
 #define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
 
-typedef mt_bfloat16 nv_bfloat16;
+typedef __mt_bfloat16 nv_bfloat16;
+typedef __mt_bfloat162 nv_bfloat162;