]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: use MMQ instead of cuBLAS by default (#8075)
authorJohannes Gäßler <redacted>
Mon, 24 Jun 2024 15:43:42 +0000 (17:43 +0200)
committerGitHub <redacted>
Mon, 24 Jun 2024 15:43:42 +0000 (17:43 +0200)
CMakeLists.txt
Makefile
README.md
ggml-cuda.cu
ggml-cuda/common.cuh
ggml-cuda/mmq.cu
ggml-cuda/mmq.cuh
ggml-cuda/mmvq.cuh

index 49ba45356a78d86f1a4ada182e9604ab0ac0b49e..1acf4bb08ba1792f0a6e1c203cc161f43bc6b272 100644 (file)
@@ -102,7 +102,8 @@ option(LLAMA_LLAMAFILE                       "llama: use llamafile SGEMM"
 option(LLAMA_CUDA                            "llama: use CUDA"                                  OFF)
 option(LLAMA_CUBLAS                          "llama: use CUDA (deprecated, use LLAMA_CUDA)"     OFF)
 option(LLAMA_CUDA_FORCE_DMMV                 "llama: use dmmv instead of mmvq CUDA kernels"     OFF)
-option(LLAMA_CUDA_FORCE_MMQ                  "llama: use mmq kernels instead of cuBLAS"         OFF)
+option(LLAMA_CUDA_FORCE_MMQ                  "llama: always use mmq kernels instead of cuBLAS"  OFF)
+option(LLAMA_CUDA_FORCE_CUBLAS               "llama: always use cuBLAS instead of mmq kernels"  OFF)
 set(LLAMA_CUDA_DMMV_X      "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
 set(LLAMA_CUDA_MMV_Y        "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
 option(LLAMA_CUDA_F16                        "llama: use 16 bit floats for some calculations"   OFF)
@@ -416,13 +417,14 @@ if (LLAMA_CUDA)
 
         if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
             # 52 == lowest CUDA 12 standard
-            # 60 == f16 CUDA intrinsics
+            # 60 == FP16 CUDA intrinsics
             # 61 == integer CUDA intrinsics
-            # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster
+            # 70 == FP16 tensor cores
+            # 75 == int8 tensor cores
             if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16)
-                set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
+                set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75")
             else()
-                set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
+                set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75")
                 #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work
             endif()
         endif()
@@ -447,6 +449,9 @@ if (LLAMA_CUDA)
         if (LLAMA_CUDA_FORCE_MMQ)
             add_compile_definitions(GGML_CUDA_FORCE_MMQ)
         endif()
+        if (LLAMA_CUDA_FORCE_CUBLAS)
+            add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
+        endif()
         if (LLAMA_CUDA_NO_VMM)
             add_compile_definitions(GGML_CUDA_NO_VMM)
         endif()
index 3aad77394c5ac86b08fb136e0df5859389f50d64..f6e8eb73eb9eb328c623b526edec6d913834549e 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -537,6 +537,9 @@ endif # LLAMA_CUDA_FORCE_DMMV
 ifdef LLAMA_CUDA_FORCE_MMQ
        MK_NVCCFLAGS += -DGGML_CUDA_FORCE_MMQ
 endif # LLAMA_CUDA_FORCE_MMQ
+ifdef LLAMA_CUDA_FORCE_CUBLAS
+       MK_NVCCFLAGS += -DGGML_CUDA_FORCE_CUBLAS
+endif # LLAMA_CUDA_FORCE_CUBLAS
 ifdef LLAMA_CUDA_DMMV_X
        MK_NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
 else
index 40793c8eab880fda2234ec766c90970bb3fbde89..a54ee3951d41dc121a4ba118ad2af0dce2c25742 100644 (file)
--- a/README.md
+++ b/README.md
@@ -510,8 +510,9 @@ Building the program with BLAS support may lead to some performance improvements
   |--------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
   | LLAMA_CUDA_FORCE_DMMV          | Boolean                | false   | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
   | LLAMA_CUDA_DMMV_X              | Positive integer >= 32 | 32      | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants.                                         |
-  | LLAMA_CUDA_MMV_Y               | Positive integer       | 1       | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended.                                               |
-  | LLAMA_CUDA_FORCE_MMQ           | Boolean                | false   | Force the use of dequantization + matrix multiplication kernels instead of leveraging Math libraries. |                                                                                                                                         |
+  | LLAMA_CUDA_MMV_Y               | Positive integer       | 1       | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended.                                                                                                                                         |
+  | LLAMA_CUDA_FORCE_MMQ           | Boolean                | false   | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). Speed for large batch sizes will be worse but VRAM consumption will be lower.                    |
+  | LLAMA_CUDA_FORCE_CUBLAS        | Boolean                | false   | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models                                                                                                                                                                                       |
   | LLAMA_CUDA_F16                 | Boolean                | false   | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs.                                                           |
   | LLAMA_CUDA_KQUANTS_ITER        | 1 or 2                 | 2       | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs.                                                                                                                     |
   | LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer       | 128     | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial.                                                                         |
index f914efd71266572fccba61c89654aa6738378b03..2dda03924253196d6d6fa4a26af7029dd90a1fdd 100644 (file)
@@ -152,16 +152,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
     GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
 
     int64_t total_vram = 0;
-#if defined(GGML_CUDA_FORCE_MMQ)
-    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:   yes\n", __func__);
+#ifdef GGML_CUDA_FORCE_MMQ
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:    yes\n", __func__);
 #else
-    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:   no\n", __func__);
-#endif
-#if defined(CUDA_USE_TENSOR_CORES)
-    GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:    no\n", __func__);
+#endif // GGML_CUDA_FORCE_MMQ
+#ifdef GGML_CUDA_FORCE_CUBLAS
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
 #else
-    GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
-#endif
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
+#endif // GGML_CUDA_FORCE_CUBLAS
     GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
     for (int id = 0; id < info.device_count; ++id) {
         int device_vmm = 0;
@@ -1873,9 +1873,17 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
 
-    int64_t min_compute_capability = INT_MAX;
+    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
+    bool          use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool              use_mul_mat_q =  ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+    bool any_gpus_with_slow_fp16 = false;
 
-    bool any_pascal_with_slow_fp16 = false;
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
         auto & tensor_split = buft_ctx->tensor_split;
@@ -1885,55 +1893,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
                 continue;
             }
 
-            if (min_compute_capability > ggml_cuda_info().devices[id].cc) {
-                min_compute_capability = ggml_cuda_info().devices[id].cc;
-            }
-            if (ggml_cuda_info().devices[id].cc == 610) {
-                any_pascal_with_slow_fp16 = true;
-            }
+            const int cc            = ggml_cuda_info().devices[id].cc;
+            use_mul_mat_vec_q       = use_mul_mat_vec_q       && cc >= MIN_CC_DP4A;
+            use_mul_mat_q           = use_mul_mat_q           && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
         }
     } else {
-        min_compute_capability    = ggml_cuda_info().devices[ctx.device].cc;
-        any_pascal_with_slow_fp16 = ggml_cuda_info().devices[ctx.device].cc == 610;
+        const int cc            = ggml_cuda_info().devices[ctx.device].cc;
+        use_mul_mat_vec_q       = use_mul_mat_vec_q       && cc >= MIN_CC_DP4A;
+        use_mul_mat_q           = use_mul_mat_q           && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
     }
 
-    // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
-
-    bool          use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
-
-    bool              use_mul_mat_q =  ggml_cuda_supports_mmq(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
-
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-
-    const bool fp16_performance_good = min_compute_capability >= CC_RDNA1;
-
-#ifdef CUDA_USE_TENSOR_CORES
-    use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3;
-#endif // CUDA_USE_TENSOR_CORES
-
-#else
-
-    // fp16 performance is good on Volta or newer and on P100 (compute capability 6.0)
-    const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16;
-
-    // mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1
-    use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A;
-    use_mul_mat_q     = use_mul_mat_q     && min_compute_capability >= MIN_CC_DP4A;
-
-#ifdef CUDA_USE_TENSOR_CORES
-    // when tensor cores are available, use them for large batch size
-    // ref: https://github.com/ggerganov/llama.cpp/pull/3776
-    use_mul_mat_q     = use_mul_mat_q     && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
-#endif // CUDA_USE_TENSOR_CORES
-
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-
     // if mmvq is available it's a better choice than dmmv:
 #ifndef GGML_CUDA_FORCE_DMMV
     use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
@@ -1947,21 +1918,22 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
-        // KQ single-batch
+    if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // FP32 precision KQ single-batch for batch size 1 without FlashAttention
         ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
-    } else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
-        // KQV single-batch
+    } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+        // FP32 precision KQV single-batch for batch size 1 without FlashAttention
         ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || fp16_performance_good) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
-        // KQ + KQV multi-batch
-        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
     } else if (use_mul_mat_vec_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
+               && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        // KQ + KQV multi-batch without FlashAttention
+        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
     }
index 5c866253586e7a99322b2b300b51cf566025c91b..8d00db6c193ff67e4f4eee5de13c8b38b9bab536 100644 (file)
 #define CC_RDNA2      (CC_OFFSET_AMD + 1030)
 #define CC_RDNA3      (CC_OFFSET_AMD + 1100)
 
-// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
-// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
-// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
-// -  7B quantum model: +100-200 MB
-// - 13B quantum model: +200-400 MB
-//
-//#define GGML_CUDA_FORCE_MMQ
-
-// TODO: improve this to be correct for more hardware
-//       for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
-#if !defined(GGML_CUDA_FORCE_MMQ)
-#define CUDA_USE_TENSOR_CORES
-#endif
-
-#define MMVQ_MAX_BATCH_SIZE  8 // max batch size to use MMVQ kernels
-#define  MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available
-
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
 #if defined(_MSC_VER)
@@ -343,15 +326,15 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
 #define INT8_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 
-static bool fast_fp16_available(const int cc) {
+static constexpr bool fast_fp16_available(const int cc) {
     return cc >= CC_PASCAL && cc != 610;
 }
 
-static bool fp16_mma_available(const int cc) {
+static constexpr bool fp16_mma_available(const int cc) {
     return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
 }
 
-static bool int8_mma_available(const int cc) {
+static constexpr bool int8_mma_available(const int cc) {
     return cc < CC_OFFSET_AMD && cc >= CC_TURING;
 }
 
@@ -643,19 +626,6 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
     static constexpr int qi = QI3_S;
 };
 
-static constexpr int get_mmq_x_max_host(int cc) {
-#ifdef CUDA_USE_TENSOR_CORES
-    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
-#else
-    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
-#endif // CUDA_USE_TENSOR_CORES
-}
-
-// Round rows to this value for --split-mode row:
-static constexpr int get_mmq_y_host(int cc) {
-    return cc >= CC_VOLTA ? 128 : 64;
-}
-
 //////////////////////
 
 struct ggml_cuda_device_info {
index 6dbd85feff2fa2e44e74e6da985deddb8fcbe44b..0308beaccbaa3876fd508b194652ca53f15e6d1d 100644 (file)
@@ -69,7 +69,13 @@ void ggml_cuda_op_mul_mat_q(
     GGML_UNUSED(src1_ddf_i);
 }
 
-bool ggml_cuda_supports_mmq(enum ggml_type type) {
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
+#ifdef GGML_CUDA_FORCE_CUBLAS
+    return false;
+#endif // GGML_CUDA_FORCE_CUBLAS
+
+    bool mmq_supported;
+
     switch (type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
@@ -81,8 +87,32 @@ bool ggml_cuda_supports_mmq(enum ggml_type type) {
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
-            return true;
+            mmq_supported = true;
+            break;
         default:
-            return false;
+            mmq_supported = false;
+            break;
+    }
+
+    if (!mmq_supported) {
+        return false;
+    }
+
+    if (int8_mma_available(cc)) {
+        return true;
+    }
+
+    if (cc < MIN_CC_DP4A) {
+        return false;
     }
+
+#ifdef GGML_CUDA_FORCE_MMQ
+    return true;
+#endif //GGML_CUDA_FORCE_MMQ
+
+    if (cc < CC_OFFSET_AMD) {
+        return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    }
+
+    return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
index 0f7f8ae51e425c65c69b88395bff5f868e53597d..1fc948be5bbe838878a37b5fc767b1411a06e434 100644 (file)
@@ -7,6 +7,8 @@
 #include <climits>
 #include <cstdint>
 
+#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
+
 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
 typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k0);
 typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
@@ -24,25 +26,42 @@ struct tile_x_sizes {
     int sc;
 };
 
-// get_mmq_x_max_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
+static constexpr int get_mmq_x_max_host(const int cc) {
+    return int8_mma_available(cc) ? 128 :
+#ifdef GGML_CUDA_FORCE_MMQ
+        cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128                     : 64;
+#else
+        cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
+#endif // GGML_CUDA_FORCE_MMQ
+}
 
 static constexpr __device__ int get_mmq_x_max_device() {
+#ifdef INT8_MMA_AVAILABLE
+    return 128;
+#else // INT8_MMA_AVAILABLE
+
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    return 64;
-#else
+    return 128;
+#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
 #if __CUDA_ARCH__ >= CC_VOLTA
-#ifdef CUDA_USE_TENSOR_CORES
-    return MMQ_MAX_BATCH_SIZE;
-#else
+#ifdef GGML_CUDA_FORCE_MMQ
+    return MMQ_DP4A_MAX_BATCH_SIZE;
+#else // GGML_CUDA_FORCE_MMQ
     return 128;
-#endif // CUDA_USE_TENSOR_CORES
-#else
+#endif // GGML_CUDA_FORCE_MMQ
+#else // __CUDA_ARCH__ >= CC_VOLTA
+
     return 64;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
+
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // INT8_MMA_AVAILABLE
 }
 
-// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
+static constexpr int get_mmq_y_host(const int cc) {
+    return int8_mma_available(cc) || cc >= CC_VOLTA ? 128 : 64;
+}
 
 static constexpr __device__ int get_mmq_y_device() {
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
@@ -2590,4 +2609,4 @@ void ggml_cuda_op_mul_mat_q(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
 
-bool ggml_cuda_supports_mmq(enum ggml_type type);
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);
index 88c42c4b7a8fbc35bf49f778a9025a64d6f7050a..d9e42fdd6d16c01b2594ffc9e6012aa0e98f37a5 100644 (file)
@@ -1,5 +1,7 @@
 #include "common.cuh"
 
+#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
+
 void ggml_cuda_op_mul_mat_vec_q(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,