]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: refactor mmq, dmmv, mmvq (llama/7716)
authorJohannes Gäßler <redacted>
Wed, 5 Jun 2024 14:53:00 +0000 (16:53 +0200)
committerGeorgi Gerganov <redacted>
Sun, 16 Jun 2024 15:19:48 +0000 (18:19 +0300)
* CUDA: refactor mmq, dmmv, mmvq

* fix out-of-bounds write

* struct for qk, qr, qi

* fix cmake build

* mmq_type_traits

110 files changed:
ggml-common.h
ggml-cuda.cu
ggml-cuda/common.cuh
ggml-cuda/dmmv.cu
ggml-cuda/mmq.cu
ggml-cuda/mmq.cuh
ggml-cuda/mmvq.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu
ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu
ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu
ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu
ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu
ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu
ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu
ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu
ggml-cuda/template-instances/generate_cu_files.py
ggml-cuda/template-instances/mmq-instance-q2_k.cu [new file with mode: 0644]
ggml-cuda/template-instances/mmq-instance-q3_k.cu [new file with mode: 0644]
ggml-cuda/template-instances/mmq-instance-q4_0.cu [new file with mode: 0644]
ggml-cuda/template-instances/mmq-instance-q4_1.cu [new file with mode: 0644]
ggml-cuda/template-instances/mmq-instance-q4_k.cu [new file with mode: 0644]
ggml-cuda/template-instances/mmq-instance-q5_0.cu [new file with mode: 0644]
ggml-cuda/template-instances/mmq-instance-q5_1.cu [new file with mode: 0644]
ggml-cuda/template-instances/mmq-instance-q5_k.cu [new file with mode: 0644]
ggml-cuda/template-instances/mmq-instance-q6_k.cu [new file with mode: 0644]
ggml-cuda/template-instances/mmq-instance-q8_0.cu [new file with mode: 0644]
ggml-cuda/vecdotq.cuh

index 77e6bfba4b11b5416e2a3d0f822885dabb409a29..e8efceb760d409d7c1fe3ea3b4c27e07d0584e8f 100644 (file)
@@ -123,12 +123,18 @@ typedef sycl::half2 ggml_half2;
 #define QI1_S (QK_K / (4*QR1_S))
 #define QR1_S 8
 
+#define QI1_M (QK_K / (4*QR1_M))
+#define QR1_M 8
+
 #define QI4_NL (QK4_NL / (4*QR4_NL))
 #define QR4_NL 2
 
 #define QI4_XS (QK_K / (4*QR4_XS))
 #define QR4_XS 8
 
+#define QI3_S (QK_K / (4*QR3_S))
+#define QR3_S 8
+
 #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
 
 #define QK4_0 32
index c81c6a0d783be118cdea5bdb64d6eb17fa939a33..dad8a9e2dafe78c3c3e44feaa371c6f27e548566 100644 (file)
@@ -633,88 +633,22 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
 
 // cuda split buffer
 
-static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
-    int64_t min_compute_capability = INT_MAX;
-    int64_t max_compute_capability = INT_MIN;
+static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
+    int64_t row_rounding = 0;
     for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-        if (tensor_split[id] < (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
-            if (min_compute_capability > ggml_cuda_info().devices[id].cc) {
-                min_compute_capability = ggml_cuda_info().devices[id].cc;
-            }
-            if (max_compute_capability < ggml_cuda_info().devices[id].cc) {
-                max_compute_capability = ggml_cuda_info().devices[id].cc;
-            }
+        if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
+            continue;
         }
-    }
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    switch(type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-            return max_compute_capability >= CC_RDNA2 ? 128 : 64;
-        case GGML_TYPE_F16:
-        case GGML_TYPE_F32:
-            return 1;
-        case GGML_TYPE_Q2_K:
-            return max_compute_capability >= CC_RDNA2 ? 128 : 32;
-        case GGML_TYPE_Q3_K:
-            return min_compute_capability < CC_RDNA2 ? 128 : 64;
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-            return max_compute_capability >= CC_RDNA2 ? 128 : 64;
-        default:
-            GGML_ASSERT(false);
-    }
-#else
-    switch(type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-            return max_compute_capability >= CC_VOLTA ? 128 : 64;
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-            return 64;
-        case GGML_TYPE_F16:
-        case GGML_TYPE_F32:
-            return 1;
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-            return max_compute_capability >= CC_VOLTA ? 128 : 64;
-        case GGML_TYPE_Q6_K:
-            return 64;
-        default:
-            GGML_ASSERT(false);
+        const int cc = ggml_cuda_info().devices[id].cc;
+        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
     }
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    return row_rounding;
 }
 
 static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
     const int64_t nrows = ggml_nrows(tensor);
-    const int64_t rounding = get_row_rounding(tensor->type, tensor_split);
+    const int64_t rounding = get_row_rounding(tensor_split);
 
     *row_low = id == 0 ? 0 : nrows*tensor_split[id];
     *row_low -= *row_low % rounding;
@@ -1499,7 +1433,7 @@ static void ggml_cuda_op_mul_mat(
         // for multi GPU, get the row boundaries from tensor split
         // and round to mul_mat_q tile sizes
         if (split) {
-            const int64_t rounding = get_row_rounding(src0->type, tensor_split);
+            const int64_t rounding = get_row_rounding(tensor_split);
 
             if (id != 0) {
                 dev[id].row_low  = ne01*tensor_split[id];
index 22872ca5c1d81b3e976c2960edb828ba26c98b08..90a0a81ead789679a148811acbfbf40201b4e932 100644 (file)
 #endif
 
 #define MMVQ_MAX_BATCH_SIZE  8 // max batch size to use MMVQ kernels
-#define  MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available
+#define  MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available
 
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -484,6 +484,161 @@ static __device__ __forceinline__ float get_alibi_slope(
     return powf(base, exph);
 }
 
+template <ggml_type type>
+struct ggml_cuda_type_traits;
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_F16> {
+    static constexpr int qk = 1;
+    static constexpr int qr = 1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
+    static constexpr int qk = QK4_0;
+    static constexpr int qr = QR4_0;
+    static constexpr int qi = QI4_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
+    static constexpr int qk = QK4_1;
+    static constexpr int qr = QR4_1;
+    static constexpr int qi = QI4_1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
+    static constexpr int qk = QK5_0;
+    static constexpr int qr = QR5_0;
+    static constexpr int qi = QI5_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
+    static constexpr int qk = QK5_1;
+    static constexpr int qr = QR5_1;
+    static constexpr int qi = QI5_1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
+    static constexpr int qk = QK8_0;
+    static constexpr int qr = QR8_0;
+    static constexpr int qi = QI8_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_K;
+    static constexpr int qi = QI2_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR3_K;
+    static constexpr int qi = QI3_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR4_K;
+    static constexpr int qi = QI4_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR5_K;
+    static constexpr int qi = QI5_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR6_K;
+    static constexpr int qi = QI6_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_XXS;
+    static constexpr int qi = QI2_XXS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_XS;
+    static constexpr int qi = QI2_XS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_S;
+    static constexpr int qi = QI2_S;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR3_XXS;
+    static constexpr int qi = QI3_XXS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR1_S;
+    static constexpr int qi = QI1_S;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR1_M;
+    static constexpr int qi = QI1_M;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
+    static constexpr int qk = QK4_NL;
+    static constexpr int qr = QR4_NL;
+    static constexpr int qi = QI4_NL;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR4_XS;
+    static constexpr int qi = QI4_XS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR3_S;
+    static constexpr int qi = QI3_S;
+};
+
+static int get_mmq_x_max_host(const int cc) {
+#ifdef CUDA_USE_TENSOR_CORES
+    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
+#else
+    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
+#endif // CUDA_USE_TENSOR_CORES
+}
+
+// Round rows to this value for --split-mode row:
+static int get_mmq_y_host(const int cc, const int mmq_x) {
+    return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
+}
+
 //////////////////////
 
 struct ggml_cuda_device_info {
index 47d4d5d9e91da5b14cdfffd896a3e5a07c014e6a..174489e0665d38636b9140b23398a13464b779e0 100644 (file)
@@ -422,10 +422,22 @@ static __device__ void convert_f16(const void * vx, const int64_t ib, const int
     v.y = x[ib + iqs + 1];
 }
 
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 :
+        type == GGML_TYPE_Q4_1 ? dequantize_q4_1 :
+        type == GGML_TYPE_Q5_0 ? dequantize_q5_0 :
+        type == GGML_TYPE_Q5_1 ? dequantize_q5_1 :
+        type == GGML_TYPE_Q8_0 ? dequantize_q8_0 :
+        type == GGML_TYPE_F16 ? convert_f16 :
+        nullptr;
+}
+
+template <ggml_type type>
 static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
-    // qk = quantized weights per x block
-    // qr = number of quantized weights per data value in x block
+    constexpr int qk = ggml_cuda_type_traits<type>::qk; // quantized weights per x block
+    constexpr int qr = ggml_cuda_type_traits<type>::qr; // number of quantized weights per data value in x block
+    constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type);
+
     const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
 
     if (row >= nrows) {
@@ -493,7 +505,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
     // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+    dequantize_mul_mat_vec<GGML_TYPE_Q4_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
@@ -502,7 +514,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+    dequantize_mul_mat_vec<GGML_TYPE_Q4_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
@@ -511,7 +523,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+    dequantize_mul_mat_vec<GGML_TYPE_Q5_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
@@ -520,7 +532,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+    dequantize_mul_mat_vec<GGML_TYPE_Q5_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
@@ -529,7 +541,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y,
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+    dequantize_mul_mat_vec<GGML_TYPE_Q8_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
@@ -580,7 +592,7 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<1, 1, convert_f16>
+    dequantize_mul_mat_vec<GGML_TYPE_F16>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
index ebe1dc5c8bb22bd22a98b1429286bcd1d9915ba2..58799e4caf6f8b6493134de36a0968634725eec3 100644 (file)
 #include "mmq.cuh"
-#include "vecdotq.cuh"
-
-typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
-typedef void (*load_tiles_cuda_t)(
-    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row);
-typedef float (*vec_dot_q_mul_mat_cuda_t)(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k);
-typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
-typedef void (mul_mat_q_t)(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst);
-
-struct mmq_arch_config_t {
-    int x;
-    int y;
-    int nwarps;
-};
-
-struct mmq_config_t {
-    mmq_arch_config_t rdna2;
-    mmq_arch_config_t rdna1;
-    mmq_arch_config_t ampere;
-    mmq_arch_config_t pascal;
-};
-
-constexpr mmq_config_t MMQ_CONFIG_Q4_0 = {
-//        x    y  nwarps
-        { 64, 128, 8},
-        { 64,  64, 8},
-#ifdef CUDA_USE_TENSOR_CORES
-        {  4,  32, 4},
-#else
-        { 64, 128, 4},
-#endif // CUDA_USE_TENSOR_CORES
-        { 64,  64, 8},
-};
-constexpr mmq_config_t MMQ_CONFIG_Q4_1 = {
-//        x    y  nwarps
-        { 64, 128, 8},
-        { 64,  64, 8},
-#ifdef CUDA_USE_TENSOR_CORES
-        {  4,  32, 4},
-#else
-        { 64, 128, 4},
-#endif // CUDA_USE_TENSOR_CORES
-        { 64,  64, 8},
-};
-constexpr mmq_config_t MMQ_CONFIG_Q5_0 = {
-//        x    y  nwarps
-        { 64, 128, 8},
-        { 64,  64, 8},
-#ifdef CUDA_USE_TENSOR_CORES
-        {  4,  32, 4},
-#else
-        {128,  64, 4},
-#endif // CUDA_USE_TENSOR_CORES
-        { 64,  64, 8},
-};
-constexpr mmq_config_t MMQ_CONFIG_Q5_1 = {
-//        x    y  nwarps
-        { 64, 128, 8},
-        { 64,  64, 8},
-#ifdef CUDA_USE_TENSOR_CORES
-        {  4,  32, 4},
-#else
-        {128,  64, 4},
-#endif // CUDA_USE_TENSOR_CORES
-        { 64,  64, 8},
-};
-constexpr mmq_config_t MMQ_CONFIG_Q8_0 = {
-//        x    y  nwarps
-        { 64, 128, 8},
-        { 64,  64, 8},
-#ifdef CUDA_USE_TENSOR_CORES
-        {  4,  32, 4},
-#else
-        {128,  64, 4},
-#endif // CUDA_USE_TENSOR_CORES
-        { 64,  64, 8},
-};
-constexpr mmq_config_t MMQ_CONFIG_Q2_K = {
-//        x    y  nwarps
-        { 64, 128, 8},
-        {128,  32, 8},
-#ifdef CUDA_USE_TENSOR_CORES
-        {  4,  32, 4},
-#else
-        { 64, 128, 4},
-#endif // CUDA_USE_TENSOR_CORES
-        { 64,  64, 8},
-};
-constexpr mmq_config_t MMQ_CONFIG_Q3_K = {
-//        x    y  nwarps
-        {128,  64, 8},
-        { 32, 128, 8},
-#ifdef CUDA_USE_TENSOR_CORES
-        {  4,  32, 4},
-#else
-        {128, 128, 4},
-#endif // CUDA_USE_TENSOR_CORES
-        { 64,  64, 8},
-};
-constexpr mmq_config_t MMQ_CONFIG_Q4_K = {
-//        x    y  nwarps
-        { 64, 128, 8},
-        { 32,  64, 8},
-#ifdef CUDA_USE_TENSOR_CORES
-        {  4,  32, 4},
-#else
-        { 64, 128, 4},
-#endif // CUDA_USE_TENSOR_CORES
-        { 64,  64, 8},
-};
-constexpr mmq_config_t MMQ_CONFIG_Q5_K = {
-//        x    y  nwarps
-        { 64, 128, 8},
-        { 32,  64, 8},
-#ifdef CUDA_USE_TENSOR_CORES
-        {  4,  32, 4},
-#else
-        { 64, 128, 4},
-#endif // CUDA_USE_TENSOR_CORES
-        { 64,  64, 8},
-};
-constexpr mmq_config_t MMQ_CONFIG_Q6_K = {
-//        x    y  nwarps
-        { 64, 128, 8},
-        { 32,  64, 8},
-#ifdef CUDA_USE_TENSOR_CORES
-        {  4,  32, 4},
-#else
-        { 64,  64, 4},
-#endif // CUDA_USE_TENSOR_CORES
-        { 64,  64, 8},
-};
-
-// ------------------------------------------------------------
-
-template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
-    GGML_UNUSED(x_qh);
-    GGML_UNUSED(x_sc);
-
-    __shared__ int  tile_x_qs[mmq_y * (WARP_SIZE)       + mmq_y];
-    __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0];
-
-    *x_ql = tile_x_qs;
-    *x_dm = (half2 *) tile_x_d;
-}
-
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
-    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
-    GGML_CUDA_ASSUME(i_offset >= 0);
-    GGML_CUDA_ASSUME(i_offset <  nwarps);
-    GGML_CUDA_ASSUME(k >= 0);
-    GGML_CUDA_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI4_0;
-    const int kqsx = k % QI4_0;
-
-    const block_q4_0 * bx0 = (const block_q4_0 *) vx;
-
-    float * x_dmf = (float *) x_dm;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
-
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
-        // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
-    const int kbxd = k % blocks_per_tile_x_row;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
-        int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
-    }
-}
-
-static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
-
-    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
-    const float * x_dmf = (const float *) x_dm;
-
-    int u[2*VDR_Q4_0_Q8_1_MMQ];
-
-#pragma unroll
-    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
-        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
-        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
-    }
-
-    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
-        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
-         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
-}
-
-template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
-
-    __shared__ int   tile_x_qs[mmq_y * (WARP_SIZE) +     + mmq_y];
-    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1];
-
-    *x_ql = tile_x_qs;
-    *x_dm = tile_x_dm;
-}
-
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
-    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
-
-    GGML_CUDA_ASSUME(i_offset >= 0);
-    GGML_CUDA_ASSUME(i_offset <  nwarps);
-    GGML_CUDA_ASSUME(k >= 0);
-    GGML_CUDA_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI4_1;
-    const int kqsx = k % QI4_1;
-
-    const block_q4_1 * bx0 = (const block_q4_1 *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
-
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
-    const int kbxd = k % blocks_per_tile_x_row;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
-        int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
-    }
-}
-
-static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
-
-    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
-
-    int u[2*VDR_Q4_1_Q8_1_MMQ];
-
-#pragma unroll
-    for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
-        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
-        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
-    }
-
-    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
-        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
-         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
-}
-
-template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
-
-    __shared__ int  tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
-    __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0];
-
-    *x_ql = tile_x_ql;
-    *x_dm = (half2 *) tile_x_d;
-}
-
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
-    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
-
-    GGML_CUDA_ASSUME(i_offset >= 0);
-    GGML_CUDA_ASSUME(i_offset <  nwarps);
-    GGML_CUDA_ASSUME(k >= 0);
-    GGML_CUDA_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI5_0;
-    const int kqsx = k % QI5_0;
-
-    const block_q5_0 * bx0 = (const block_q5_0 *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;
-
-        const int ql = get_int_from_uint8(bxi->qs, kqsx);
-        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));
-
-        int qs0 = (ql >>  0)   & 0x0F0F0F0F;
-        qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4
-        qs0    |= (qh << 11)   & 0x00001000;  // 1 -> 12
-        qs0    |= (qh << 18)   & 0x00100000;  // 2 -> 20
-        qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
-        qs0     = __vsubss4(qs0, 0x10101010); // subtract 16
-
-        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
-
-        int qs1 = (ql >>  4)   & 0x0F0F0F0F;
-        qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
-        qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12
-        qs1    |= (qh <<  2)   & 0x00100000;  // 18 -> 20
-        qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
-        qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
-
-        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
-    const int kbxd = k % blocks_per_tile_x_row;
-    float * x_dmf = (float *) x_dm;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
-        int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
-    }
-}
-
-static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
-
-    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
-    const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
-    const float * x_dmf = (const float *) x_dm;
-    const float * y_df  = (const float *) y_ds;
-
-    int u[2*VDR_Q5_0_Q8_1_MMQ];
-
-#pragma unroll
-    for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
-        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
-        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
-    }
-
-    return vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
-        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
-}
-
-
-template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
-
-    __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
-    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1];
-
-    *x_ql = tile_x_ql;
-    *x_dm = tile_x_dm;
-}
-
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
-    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
-
-    GGML_CUDA_ASSUME(i_offset >= 0);
-    GGML_CUDA_ASSUME(i_offset < nwarps);
-    GGML_CUDA_ASSUME(k >= 0);
-    GGML_CUDA_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI5_1;
-    const int kqsx = k % QI5_1;
-
-    const block_q5_1 * bx0 = (const block_q5_1 *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
-
-        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
-        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
-
-        int qs0 = (ql >>  0) & 0x0F0F0F0F;
-        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
-        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12
-        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
-        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
-
-        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
-
-        int qs1 = (ql >>  4) & 0x0F0F0F0F;
-        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
-        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
-        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
-        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
-
-        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
-    const int kbxd = k % blocks_per_tile_x_row;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
-        int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
-    }
-}
-
-static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
-
-    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
-    const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
-
-    int u[2*VDR_Q5_1_Q8_1_MMQ];
-
-#pragma unroll
-    for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
-        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
-        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
-    }
-
-    return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
-        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
-}
-
-template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
-
-    __shared__ int  tile_x_qs[mmq_y * (WARP_SIZE)       + mmq_y];
-    __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0];
-
-    *x_ql = tile_x_qs;
-    *x_dm = (half2 *) tile_x_d;
-}
-
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
-    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
-
-    GGML_CUDA_ASSUME(i_offset >= 0);
-    GGML_CUDA_ASSUME(i_offset <  nwarps);
-    GGML_CUDA_ASSUME(k >= 0);
-    GGML_CUDA_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI8_0;
-    const int kqsx = k % QI8_0;
-    float * x_dmf = (float *) x_dm;
-
-    const block_q8_0 * bx0 = (const block_q8_0 *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
-
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
-    const int kbxd = k % blocks_per_tile_x_row;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
-        int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
-    }
-}
-
-static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
-    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
-
-    const float * x_dmf = (const float *) x_dm;
-    const float * y_df  = (const float *) y_ds;
-
-    return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
-        (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
-         y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
-}
-
-template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
-    GGML_UNUSED(x_qh);
-
-    __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE)       + mmq_y];
-    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K];
-    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/4)     + mmq_y/4];
-
-    *x_ql = tile_x_ql;
-    *x_dm = tile_x_dm;
-    *x_sc = tile_x_sc;
-}
-
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
-    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
-    GGML_UNUSED(x_qh);
-
-    GGML_CUDA_ASSUME(i_offset >= 0);
-    GGML_CUDA_ASSUME(i_offset <  nwarps);
-    GGML_CUDA_ASSUME(k >= 0);
-    GGML_CUDA_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI2_K;
-    const int kqsx = k % QI2_K;
-
-    const block_q2_K * bx0 = (const block_q2_K *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;
-
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
-    const int kbxd = k % blocks_per_tile_x_row;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
-        int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
-        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);
-
-        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
-    }
-}
-
-static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
-    GGML_UNUSED(x_qh);
-
-    const int kbx = k / QI2_K;
-    const int ky  = (k % QI2_K) * QR2_K;
-    const float * y_df = (const float *) y_ds;
-
-    int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
-
-    const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
-    const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
-
-#pragma unroll
-    for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
-        v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
-    }
-
-    const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
-
-    const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;
-    return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
-}
-
-template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
-
-    __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE)       + mmq_y];
-    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K];
-    __shared__ int   tile_x_qh[mmq_y * (WARP_SIZE/2)     + mmq_y/2];
-    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/4)     + mmq_y/4];
-
-    *x_ql = tile_x_ql;
-    *x_dm = tile_x_dm;
-    *x_qh = tile_x_qh;
-    *x_sc = tile_x_sc;
-}
-
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
-    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
-
-    GGML_CUDA_ASSUME(i_offset >= 0);
-    GGML_CUDA_ASSUME(i_offset <  nwarps);
-    GGML_CUDA_ASSUME(k >= 0);
-    GGML_CUDA_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI3_K;
-    const int kqsx = k % QI3_K;
-
-    const block_q3_K * bx0 = (const block_q3_K *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;
-
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
-    const int kbxd = k % blocks_per_tile_x_row;
-    float * x_dmf = (float *) x_dm;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
-        int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
-        int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
-
-        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
-        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
-        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
-
-        const int ksc = k % (QI3_K/4);
-
-        const int ksc_low = ksc % (QI3_K/8);
-        const int shift_low = 4 * (ksc / (QI3_K/8));
-        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
-
-        const int ksc_high = QI3_K/8;
-        const int shift_high = 2 * ksc;
-        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
-
-        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
-
-        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
-    }
-}
-
-static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
-
-    const int kbx  = k / QI3_K;
-    const int ky  = (k % QI3_K) * QR3_K;
-    const float * x_dmf = (const float *) x_dm;
-    const float * y_df  = (const float *) y_ds;
-
-    const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
-
-    int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
-
-#pragma unroll
-    for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
-        const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
-        const int shift = 2 * ((ky % 32) / 8);
-        const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
-
-        const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
-        const int vlh = (vh << 2) & 0x04040404;
-
-        v[l] = __vsubss4(vll, vlh);
-    }
-
-    const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
-    return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
-}
-
-template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
-    GGML_UNUSED(x_qh);
-
-    __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE)       + mmq_y];
-    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K];
-    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/8)     + mmq_y/8];
-
-    *x_ql = tile_x_ql;
-    *x_dm = tile_x_dm;
-    *x_sc = tile_x_sc;
-}
-
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
-    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
-    GGML_UNUSED(x_qh);
-
-    GGML_CUDA_ASSUME(i_offset >= 0);
-    GGML_CUDA_ASSUME(i_offset <  nwarps);
-    GGML_CUDA_ASSUME(k >= 0);
-    GGML_CUDA_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI4_K; // == 0 if QK_K == 256
-    const int kqsx = k % QI4_K; // == k if QK_K == 256
-
-    const block_q4_K * bx0 = (const block_q4_K *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;
-
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
-    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
-        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
-        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
-
-        const int * scales = (const int *) bxi->scales;
-
-        const int ksc = k % (WARP_SIZE/8);
-
-        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
-        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
-        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
-
-        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
-    }
-}
-
-static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
-    GGML_UNUSED(x_qh);
-
-    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
-
-    const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
-    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
-                                      x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
-}
-
-template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
-    GGML_UNUSED(x_qh);
-
-    __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
-    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K];
-    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/8)     + mmq_y/8];
-
-    *x_ql = tile_x_ql;
-    *x_dm = tile_x_dm;
-    *x_sc = tile_x_sc;
-}
-
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
-    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
-    GGML_UNUSED(x_qh);
-
-    GGML_CUDA_ASSUME(i_offset >= 0);
-    GGML_CUDA_ASSUME(i_offset <  nwarps);
-    GGML_CUDA_ASSUME(k >= 0);
-    GGML_CUDA_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI5_K; // == 0 if QK_K == 256
-    const int kqsx = k % QI5_K; // == k if QK_K == 256
-
-    const block_q5_K * bx0 = (const block_q5_K *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
-        const int ky = QR5_K*kqsx;
-
-        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
-        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
-        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
-
-        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
-        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
-        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
-
-        const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
-        const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);
-
-        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
-        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
-    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
-        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
-        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
-
-        const int * scales = (const int *) bxi->scales;
-
-        const int ksc = k % (WARP_SIZE/8);
-
-        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
-        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
-        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
-
-        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
-    }
-}
-
-static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
-    GGML_UNUSED(x_qh);
-
-    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
-
-    const int index_x = i * (QR5_K*WARP_SIZE + 1) +  QR5_K*k;
-    const int index_y = j * WARP_SIZE             + (QR5_K*k) % WARP_SIZE;
-    return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
-                                      x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
-}
-
-template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
-    GGML_UNUSED(x_qh);
-
-    __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
-    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K];
-    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/8)     + mmq_y/8];
-
-    *x_ql = tile_x_ql;
-    *x_dm = tile_x_dm;
-    *x_sc = tile_x_sc;
-}
-
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
-    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
-    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
-    GGML_UNUSED(x_qh);
-
-    GGML_CUDA_ASSUME(i_offset >= 0);
-    GGML_CUDA_ASSUME(i_offset <  nwarps);
-    GGML_CUDA_ASSUME(k >= 0);
-    GGML_CUDA_ASSUME(k <  WARP_SIZE);
-
-    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
-    const int kqsx = k % QI6_K; // == k if QK_K == 256
-
-    const block_q6_K * bx0 = (const block_q6_K *) vx;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + i_offset;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
-        const int ky = QR6_K*kqsx;
-
-        const int ql = get_int_from_uint8(bxi->ql, kqsx);
-        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
-        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
-
-        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
-        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
-        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;
-
-        const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
-        const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);
-
-        x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
-        x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
-    }
-
-    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
-    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
-    float * x_dmf = (float *) x_dm;
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
-        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;
-
-        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
-        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;
-
-        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
-    }
-}
-
-static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
-    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
-    GGML_UNUSED(x_qh);
-
-    const float * x_dmf = (const float *) x_dm;
-    const float * y_df  = (const float *) y_ds;
-
-    const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
-
-    const int index_x = i * (QR6_K*WARP_SIZE + 1) +  QR6_K*k;
-    const int index_y = j * WARP_SIZE             + (QR6_K*k) % WARP_SIZE;
-    return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
-}
-
-template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
-              allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
-static __device__ __forceinline__ void mul_mat_q(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
-
-    const block_q_t  * x = (const block_q_t  *) vx;
-    const block_q8_1 * y = (const block_q8_1 *) vy;
-
-    const int blocks_per_row_x = ncols_x / qk;
-    const int blocks_per_col_y = nrows_y / QK8_1;
-    const int blocks_per_warp = WARP_SIZE / qi;
-
-    const int & ncols_dst = ncols_y;
-
-    const int row_dst_0 = blockIdx.x*mmq_y;
-    const int & row_x_0 = row_dst_0;
-
-    const int col_dst_0 = blockIdx.y*mmq_x;
-    const int & col_y_0 = col_dst_0;
-
-    int   * tile_x_ql = nullptr;
-    half2 * tile_x_dm = nullptr;
-    int   * tile_x_qh = nullptr;
-    int   * tile_x_sc = nullptr;
-
-    allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
-
-    __shared__ int    tile_y_qs[mmq_x * WARP_SIZE];
-    __shared__ half2  tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
-
-    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
-
-    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
-
-        load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
-                   threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);
-
-#pragma unroll
-        for (int ir = 0; ir < qr; ++ir) {
-            const int kqs = ir*WARP_SIZE + threadIdx.x;
-            const int kbxd = kqs / QI8_1;
-
-#pragma unroll
-            for (int i = 0; i < mmq_x; i += nwarps) {
-                const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
-
-                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
-
-                const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE;
-                tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
-            }
-
-#pragma unroll
-            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
-                const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
-                const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
-                const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
-
-                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
-                const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
-                half2       * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
-                if (need_sum) {
-                    *dsi_dst = *dsi_src;
-                } else {
-                    float * dfi_dst = (float *) dsi_dst;
-                    *dfi_dst = __low2float(*dsi_src);
-                }
-            }
-
-            __syncthreads();
-
-// #pragma unroll // unrolling this loop causes too much register pressure
-            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
-#pragma unroll
-                for (int j = 0; j < mmq_x; j += nwarps) {
-#pragma unroll
-                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {
-                        sum[i/WARP_SIZE][j/nwarps] += vec_dot(
-                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
-                            threadIdx.x + i, threadIdx.y + j, k);
-                    }
-                }
-            }
-
-            __syncthreads();
-        }
-    }
-
-#pragma unroll
-    for (int j = 0; j < mmq_x; j += nwarps) {
-        const int col_dst = col_dst_0 + j + threadIdx.y;
-
-        if (col_dst >= ncols_dst) {
-            return;
-        }
-
-#pragma unroll
-        for (int i = 0; i < mmq_y; i += WARP_SIZE) {
-            const int row_dst = row_dst_0 + threadIdx.x + i;
-
-            if (row_dst >= nrows_dst) {
-                continue;
-            }
-
-            dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];
-        }
-    }
-}
-
-static constexpr __device__ mmq_arch_config_t get_arch_config_device(mmq_config_t mmq_config) {
-
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-
-#if defined(RDNA3) || defined(RDNA2)
-    return mmq_config.rdna2;
-#else
-    return mmq_config.rdna1;
-#endif // defined(RDNA3) || defined(RDNA2)
-
-#else
-
-#if __CUDA_ARCH__ >= CC_VOLTA
-    return mmq_config.ampere;
-#else
-    return mmq_config.pascal;
-#endif // __CUDA_ARCH__ >= CC_VOLTA
-
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-}
-
-template <bool need_check> static __global__ void
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_0.rdna2.nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    mul_mat_q4_0(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
-    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q4_0);
-
-    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q4_0<arch_config.y>,
-        load_tiles_q4_0<arch_config.y, arch_config.nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-#else
-    GGML_UNUSED(get_arch_config_device);
-    GGML_UNUSED(vec_dot_q4_0_q8_1_mul_mat);
-    NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
-}
-
-template <bool need_check> static __global__ void
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_1.rdna2.nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_VOLTA
-    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_1.pascal.nwarps, 2)
-#endif // __CUDA_ARCH__ < CC_VOLTA
-    mul_mat_q4_1(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
-    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q4_1);
-
-    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q4_1<arch_config.y>,
-        load_tiles_q4_1<arch_config.y, arch_config.nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-#else
-    GGML_UNUSED(get_arch_config_device);
-    GGML_UNUSED(vec_dot_q4_1_q8_1_mul_mat);
-    NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
-}
-
-template <bool need_check> static __global__ void
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q5_0.rdna2.nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    mul_mat_q5_0(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
-    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q5_0);
-
-    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q5_0<arch_config.y>,
-        load_tiles_q5_0<arch_config.y, arch_config.nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-#else
-    GGML_UNUSED(get_arch_config_device);
-    GGML_UNUSED(vec_dot_q5_0_q8_1_mul_mat);
-    NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
-}
-
-template <bool need_check> static __global__ void
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q5_1.rdna2.nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-mul_mat_q5_1(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
-    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q5_1);
-
-    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q5_1<arch_config.y>,
-        load_tiles_q5_1<arch_config.y, arch_config.nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-#else
-    GGML_UNUSED(get_arch_config_device);
-    GGML_UNUSED(vec_dot_q5_1_q8_1_mul_mat);
-    NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
-}
-
-template <bool need_check> static __global__ void
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q8_0.rdna2.nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    mul_mat_q8_0(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
-    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q8_0);
-
-    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q8_0<arch_config.y>,
-        load_tiles_q8_0<arch_config.y, arch_config.nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-#else
-    GGML_UNUSED(get_arch_config_device);
-    GGML_UNUSED(vec_dot_q8_0_q8_1_mul_mat);
-    NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
-}
-
-template <bool need_check> static __global__ void
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q2_K.rdna2.nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-mul_mat_q2_K(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
-    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q2_K);
-
-    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q2_K<arch_config.y>,
-        load_tiles_q2_K<arch_config.y, arch_config.nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-#else
-    GGML_UNUSED(get_arch_config_device);
-    GGML_UNUSED(vec_dot_q2_K_q8_1_mul_mat);
-    NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
-}
-
-template <bool need_check> static __global__ void
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q3_K.rdna2.nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_VOLTA
-    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q3_K.pascal.nwarps, 2)
-#endif // __CUDA_ARCH__ < CC_VOLTA
-    mul_mat_q3_K(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
-    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q3_K);
-
-    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q3_K<arch_config.y>,
-        load_tiles_q3_K<arch_config.y, arch_config.nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-#else
-    GGML_UNUSED(get_arch_config_device);
-    GGML_UNUSED(vec_dot_q3_K_q8_1_mul_mat);
-    NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
-}
-
-template <bool need_check> static __global__ void
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_K.rdna2.nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_VOLTA
-    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_K.pascal.nwarps, 2)
-#endif // __CUDA_ARCH__ < CC_VOLTA
-    mul_mat_q4_K(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
-    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q4_K);
-
-    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q4_K<arch_config.y>,
-        load_tiles_q4_K<arch_config.y, arch_config.nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-#else
-    GGML_UNUSED(get_arch_config_device);
-    GGML_UNUSED(vec_dot_q4_K_q8_1_mul_mat);
-    NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
-}
-
-template <bool need_check> static __global__ void
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q5_K.rdna2.nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-mul_mat_q5_K(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
-    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q5_K);
-
-    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q5_K<arch_config.y>,
-        load_tiles_q5_K<arch_config.y, arch_config.nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-#else
-    GGML_UNUSED(get_arch_config_device);
-    GGML_UNUSED(vec_dot_q5_K_q8_1_mul_mat);
-    NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
-}
-
-template <bool need_check> static __global__ void
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q6_K.rdna2.nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_VOLTA
-    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_K.pascal.nwarps, 2)
-#endif // __CUDA_ARCH__ < CC_VOLTA
-    mul_mat_q6_K(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
-    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q6_K);
-
-    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q6_K<arch_config.y>,
-        load_tiles_q6_K<arch_config.y, arch_config.nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-#else
-    GGML_UNUSED(get_arch_config_device);
-    GGML_UNUSED(vec_dot_q6_K_q8_1_mul_mat);
-    NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
-}
-
-#define MMQ_SWITCH_CASE(type_suffix)                                                                        \
-    case GGML_TYPE_Q##type_suffix: if (row_diff % arch_config.y == 0) {                                     \
-        const bool need_check = false;                                                                      \
-        mul_mat_q##type_suffix<need_check><<<block_nums, block_dims, 0, stream>>>                           \
-            (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst); \
-    } else {                                                                                                \
-        const bool need_check = true;                                                                       \
-        mul_mat_q##type_suffix<need_check><<<block_nums, block_dims, 0, stream>>>                           \
-            (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst); \
-    } break;                                                                                                \
 
 void ggml_cuda_op_mul_mat_q(
     ggml_backend_cuda_context & ctx,
@@ -1454,12 +8,15 @@ void ggml_cuda_op_mul_mat_q(
 
     const int64_t ne00 = src0->ne[0];
 
+    const int64_t nb01 = src0->nb[1];
+
     const int64_t ne10 = src1->ne[0];
     GGML_ASSERT(ne10 % QK8_1 == 0);
 
     const int64_t ne0 = dst->ne[0];
 
     const int64_t row_diff = row_high - row_low;
+    const int64_t stride00 = nb01 / ggml_type_size(src0->type);
 
     int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
@@ -1468,73 +25,39 @@ void ggml_cuda_op_mul_mat_q(
     // nrows_dst == nrows of the matrix that the kernel writes into
     const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
 
-    mmq_config_t mmq_config;
+    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, nrows_dst};
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            mmq_config = MMQ_CONFIG_Q4_0;
+            mul_mat_q_case<GGML_TYPE_Q4_0>(args, stream);
             break;
         case GGML_TYPE_Q4_1:
-            mmq_config = MMQ_CONFIG_Q4_1;
+            mul_mat_q_case<GGML_TYPE_Q4_1>(args, stream);
             break;
         case GGML_TYPE_Q5_0:
-            mmq_config = MMQ_CONFIG_Q5_0;
+            mul_mat_q_case<GGML_TYPE_Q5_0>(args, stream);
             break;
         case GGML_TYPE_Q5_1:
-            mmq_config = MMQ_CONFIG_Q5_1;
+            mul_mat_q_case<GGML_TYPE_Q5_1>(args, stream);
             break;
         case GGML_TYPE_Q8_0:
-            mmq_config = MMQ_CONFIG_Q8_0;
+            mul_mat_q_case<GGML_TYPE_Q8_0>(args, stream);
             break;
         case GGML_TYPE_Q2_K:
-            mmq_config = MMQ_CONFIG_Q2_K;
+            mul_mat_q_case<GGML_TYPE_Q2_K>(args, stream);
             break;
         case GGML_TYPE_Q3_K:
-            mmq_config = MMQ_CONFIG_Q3_K;
+            mul_mat_q_case<GGML_TYPE_Q3_K>(args, stream);
             break;
         case GGML_TYPE_Q4_K:
-            mmq_config = MMQ_CONFIG_Q4_K;
+            mul_mat_q_case<GGML_TYPE_Q4_K>(args, stream);
             break;
         case GGML_TYPE_Q5_K:
-            mmq_config = MMQ_CONFIG_Q5_K;
+            mul_mat_q_case<GGML_TYPE_Q5_K>(args, stream);
             break;
         case GGML_TYPE_Q6_K:
-            mmq_config = MMQ_CONFIG_Q6_K;
-            break;
-        default:
-            GGML_ASSERT(false);
+            mul_mat_q_case<GGML_TYPE_Q6_K>(args, stream);
             break;
-    }
-
-    mmq_arch_config_t arch_config;
-    if (compute_capability >= CC_RDNA2) {
-        arch_config = mmq_config.rdna2;
-    } else if (compute_capability >= CC_OFFSET_AMD) {
-        arch_config = mmq_config.rdna1;
-    } else if (compute_capability >= CC_VOLTA) {
-        arch_config = mmq_config.ampere;
-    } else if (compute_capability >= MIN_CC_DP4A) {
-        arch_config = mmq_config.pascal;
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    const int block_num_x = (row_diff   + arch_config.y - 1) / arch_config.y;
-    const int block_num_y = (src1_ncols + arch_config.x - 1) / arch_config.x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, arch_config.nwarps, 1);
-
-    switch (src0->type) {
-        MMQ_SWITCH_CASE(4_0)
-        MMQ_SWITCH_CASE(4_1)
-        MMQ_SWITCH_CASE(5_0)
-        MMQ_SWITCH_CASE(5_1)
-        MMQ_SWITCH_CASE(8_0)
-        MMQ_SWITCH_CASE(2_K)
-        MMQ_SWITCH_CASE(3_K)
-        MMQ_SWITCH_CASE(4_K)
-        MMQ_SWITCH_CASE(5_K)
-        MMQ_SWITCH_CASE(6_K)
         default:
             GGML_ASSERT(false);
             break;
index 807817c4a715fa08331560de51a0394c23c85a84..6744cce6d785f7a951770b2780e47083c7c44dcf 100644 (file)
 #include "common.cuh"
+#include "vecdotq.cuh"
+
+#include <climits>
+#include <cstdint>
+
+typedef void (*load_tiles_mmq_t)(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
+typedef void (*vec_dot_mmq_t)(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, float * __restrict__ sum, const int & k0);
+
+struct tile_x_sizes {
+    int ql;
+    int dm;
+    int qh;
+    int sc;
+};
+
+// get_mmq_x_max_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
+
+static constexpr __device__ int get_mmq_x_max_device() {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    return 64;
+#else
+#if __CUDA_ARCH__ >= CC_VOLTA
+#ifdef CUDA_USE_TENSOR_CORES
+    return MMQ_MAX_BATCH_SIZE;
+#else
+    return 128;
+#endif // CUDA_USE_TENSOR_CORES
+#else
+    return 64;
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+}
+
+// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+static constexpr __device__ int get_mmq_y_device(int mmq_x) {
+    return mmq_x >= 32 ? 128 : 64;
+}
+#else
+#if __CUDA_ARCH__ >= CC_VOLTA
+static constexpr __device__ int get_mmq_y_device(int mmq_x) {
+    return mmq_x >= 32 ? 128 : 64;
+}
+#else
+static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
+    return 64;
+}
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0,                           0}
+#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0,                           0}
+#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0,                           0}
+#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0,                           0}
+#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0,                           0}
+#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI2_K + mmq_y/QI2_K, 0,                           mmq_y*WARP_SIZE/4 + mmq_y/4}
+#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/2 + mmq_y/2, mmq_y*WARP_SIZE/4 + mmq_y/4}
+#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, 0,                           mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, 0,                           mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, 0,                           mmq_y*WARP_SIZE/8 + mmq_y/8}
+
+#define GET_TILE_X_SIZES_BODY                           \
+    return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
+        type == GGML_TYPE_Q4_1 ? TILE_X_SIZES_Q4_1 :    \
+        type == GGML_TYPE_Q5_0 ? TILE_X_SIZES_Q5_0 :    \
+        type == GGML_TYPE_Q5_1 ? TILE_X_SIZES_Q5_1 :    \
+        type == GGML_TYPE_Q8_0 ? TILE_X_SIZES_Q8_0 :    \
+        type == GGML_TYPE_Q2_K ? TILE_X_SIZES_Q2_K :    \
+        type == GGML_TYPE_Q3_K ? TILE_X_SIZES_Q3_K :    \
+        type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K :    \
+        type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K :    \
+        type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K :    \
+        tile_x_sizes{0, 0, 0, 0}
+
+static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
+    GET_TILE_X_SIZES_BODY;
+}
+
+template <int mmq_y>
+static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) {
+    GET_TILE_X_SIZES_BODY;
+}
+
+// ------------------------------------------------------------
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    const int kbx  = threadIdx.x / QI4_0;
+    const int kqsx = threadIdx.x % QI4_0;
+
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
+        int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
+            const float * x_dmf = (const float *) x_dm;
+
+            int u[2*VDR_Q4_0_Q8_1_MMQ];
+
+#pragma unroll
+            for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
+                u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+                u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
+            }
+
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+                (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
+                y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    const int kbx  = threadIdx.x / QI4_1;
+    const int kqsx = threadIdx.x % QI4_1;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
+        int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
+
+            int u[2*VDR_Q4_1_Q8_1_MMQ];
+
+#pragma unroll
+            for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
+                u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+                u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
+            }
+
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+                (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
+                y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    const int kbx  = threadIdx.x / QI5_0;
+    const int kqsx = threadIdx.x % QI5_0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
+
+        const int ql = get_int_from_uint8(bxi->qs, kqsx);
+        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
+
+        int qs0 = (ql >>  0)   & 0x0F0F0F0F;
+        qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4
+        qs0    |= (qh << 11)   & 0x00001000;  // 1 -> 12
+        qs0    |= (qh << 18)   & 0x00100000;  // 2 -> 20
+        qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
+        qs0     = __vsubss4(qs0, 0x10101010); // subtract 16
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
+
+        int qs1 = (ql >>  4)   & 0x0F0F0F0F;
+        qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
+        qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12
+        qs1    |= (qh <<  2)   & 0x00100000;  // 18 -> 20
+        qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
+        qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
+        int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
+            const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0;
+            const float * x_dmf = (const float *) x_dm;
+            const float * y_df  = (const float *) y_ds;
+
+            int u[2*VDR_Q5_0_Q8_1_MMQ];
+
+#pragma unroll
+            for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
+                u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+                u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
+            }
+
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
+                (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+        }
+    }
+}
+
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    const int kbx  = threadIdx.x / QI5_1;
+    const int kqsx = threadIdx.x % QI5_1;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
+
+        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
+
+        int qs0 = (ql >>  0) & 0x0F0F0F0F;
+        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
+        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12
+        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
+        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
+
+        int qs1 = (ql >>  4) & 0x0F0F0F0F;
+        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
+        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
+        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
+        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
+        int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
+            const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k0/QI5_1;
+
+            int u[2*VDR_Q5_1_Q8_1_MMQ];
+
+#pragma unroll
+            for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
+                u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+                u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
+            }
+
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+                (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    const int kbx  = threadIdx.x / QI8_0;
+    const int kqsx = threadIdx.x % QI8_0;
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
+        int i = i0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const float * x_dmf = (const float *) x_dm;
+            const float * y_df  = (const float *) y_ds;
+
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
+                (&x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[j * WARP_SIZE + k0], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
+                y_df[j * (WARP_SIZE/QI8_1) + k0/QI8_1]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh);
+
+    const int kbx  = threadIdx.x / QI2_K;
+    const int kqsx = threadIdx.x % QI2_K;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
+        int i = (i0 + threadIdx.y * QI2_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI2_K/4);
+
+        x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, threadIdx.x % (QI2_K/4));
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const int kbx = k0 / QI2_K;
+            const int ky  = (k0 % QI2_K) * QR2_K;
+            const float * y_df = (const float *) y_ds;
+
+            int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
+
+            const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
+            const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
+
+#pragma unroll
+            for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
+                v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
+            }
+
+            const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
+
+            const int index_y = j * WARP_SIZE + (QR2_K*k0) % WARP_SIZE;
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
+                v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+
+    const int kbx  = threadIdx.x / QI3_K;
+    const int kqsx = threadIdx.x % QI3_K;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
+        int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
+        int i = i0 + threadIdx.y * 2 + threadIdx.x / (WARP_SIZE/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/2)) / (QI3_K/2);
+
+        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+        x_qh[i * (WARP_SIZE/2) + i / 2 + threadIdx.x % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, threadIdx.x % (QI3_K/2));
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4);
+
+        const int ksc = threadIdx.x % (QI3_K/4);
+
+        const int ksc_low = ksc % (QI3_K/8);
+        const int shift_low = 4 * (ksc / (QI3_K/8));
+        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+
+        const int ksc_high = QI3_K/8;
+        const int shift_high = 2 * ksc;
+        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+
+        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
+
+        x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = sc;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const int kbx  = k0 / QI3_K;
+            const int ky  = (k0 % QI3_K) * QR3_K;
+            const float * x_dmf = (const float *) x_dm;
+            const float * y_df  = (const float *) y_ds;
+
+            const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+
+            int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
+
+#pragma unroll
+            for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
+                const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
+                const int shift = 2 * ((ky % 32) / 8);
+                const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
+
+                const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
+                const int vlh = (vh << 2) & 0x04040404;
+
+                v[l] = __vsubss4(vll, vlh);
+            }
+
+            const int index_y = j * WARP_SIZE + (k0*QR3_K) % WARP_SIZE;
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
+                v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh);
+
+    const int kbx  = 0;           // threadIdx.x / QI4_K
+    const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_K;  // == 1 if QK_K == 256
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
+        int i = (i0 + threadIdx.y * QI4_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
+
+        const int * scales = (const int *) bxi->scales;
+
+        const int ksc = threadIdx.x % (WARP_SIZE/8);
+
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
+        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
+
+            const int index_y = j * WARP_SIZE + (QR4_K*k0) % WARP_SIZE;
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
+                &x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh);
+
+    const int kbx  = 0;           // threadIdx.x / QI5_K
+    const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx;
+        const int ky = QR5_K*kqsx;
+
+        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
+        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+
+        const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
+        const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
+
+        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
+        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_K;  // == 1 if QK_K == 256
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
+        int i = (i0 + threadIdx.y * QI5_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI5_K/8);
+
+        const int * scales = (const int *) bxi->scales;
+
+        const int ksc = threadIdx.x % (WARP_SIZE/8);
+
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
+        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
+
+            const int index_x = i * (QR5_K*WARP_SIZE + 1) +  QR5_K*k0;
+            const int index_y = j * WARP_SIZE             + (QR5_K*k0) % WARP_SIZE;
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
+                &x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh);
+
+    const int kbx  = 0;           // threadIdx.x / QI6_K
+    const int kqsx = threadIdx.x; // threadIdx.x % QI6_K
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx;
+        const int ky = QR6_K*kqsx;
+
+        const int ql = get_int_from_uint8(bxi->ql, kqsx);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
+        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
+        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;
+
+        const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
+        const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);
+
+        x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K;  // == 1 if QK_K == 256
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
+        int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const float * x_dmf = (const float *) x_dm;
+            const float * y_df  = (const float *) y_ds;
+
+            const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
+
+            const int index_x = i * (QR6_K*WARP_SIZE + 1) +  QR6_K*k0;
+            const int index_y = j * WARP_SIZE             + (QR6_K*k0) % WARP_SIZE;
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
+                &x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
+        }
+    }
+}
+
+// -------------------------------------------------------------------------------------------------------------------------------------
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
+struct mmq_type_traits;
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
+    static constexpr bool             need_sum   = true;
+    static constexpr int              vdr        = VDR_Q4_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
+    static constexpr bool             need_sum   = true;
+    static constexpr int              vdr        = VDR_Q4_1_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
+    static constexpr bool             need_sum   = false;
+    static constexpr int              vdr        = VDR_Q5_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
+    static constexpr bool             need_sum   = true;
+    static constexpr int              vdr        = VDR_Q5_1_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
+    static constexpr bool             need_sum   = false;
+    static constexpr int              vdr        = VDR_Q8_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
+    static constexpr bool             need_sum   = false;
+    static constexpr int              vdr        = VDR_Q2_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
+    static constexpr bool             need_sum   = false;
+    static constexpr int              vdr        = VDR_Q3_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
+    static constexpr bool             need_sum   = true;
+    static constexpr int              vdr        = VDR_Q4_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
+    static constexpr bool             need_sum   = true;
+    static constexpr int              vdr        = VDR_Q5_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
+    static constexpr bool             need_sum   = false;
+    static constexpr int              vdr        = VDR_Q6_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#else
+#if __CUDA_ARCH__ >= CC_VOLTA
+    __launch_bounds__(WARP_SIZE*nwarps, 1)
+#else
+    __launch_bounds__(WARP_SIZE*nwarps, type == GGML_TYPE_Q2_K ? 1 : 2)
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+static __global__ void mul_mat_q(
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
+    const int ne00, const int ne01, const int stride00, const int ne10, const int ne11, const int ne0) {
+
+    // Skip unused template specializations for faster compilation:
+    if (mmq_x > get_mmq_x_max_device()) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    constexpr int              qk         = ggml_cuda_type_traits<type>::qk;
+    constexpr int              qr         = ggml_cuda_type_traits<type>::qr;
+    constexpr int              qi         = ggml_cuda_type_traits<type>::qi;
+    constexpr int              mmq_y      = get_mmq_y_device(mmq_x);
+    constexpr bool             need_sum   = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::need_sum;
+    constexpr int              vdr        = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
+    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
+    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
+
+    constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
+
+    extern __shared__ char data_mul_mat_q[];
+    int   * tile_x_ql = (int   *)  data_mul_mat_q;
+    half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql);
+    int   * tile_x_qh = (int   *) (tile_x_dm + txs.dm);
+    int   * tile_x_sc = (int   *) (tile_x_qh + txs.qh);
+    int   * tile_y_qs = (int   *) (tile_x_sc + txs.sc);          // [mmq_x * WARP_SIZE]
+    half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1];
+
+    const block_q8_1 * y = (const block_q8_1 *) yc;
+
+    const int blocks_per_row_x = ne00 / qk;
+    const int blocks_per_col_y = ne10 / QK8_1;
+    const int blocks_per_warp = WARP_SIZE / qi;
+
+    const int & ne1 = ne11;
+
+    const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;
+
+    float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};
+
+    for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
+
+        load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride00*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride00);
+
+#pragma unroll
+        for (int kr = 0; kr < qr; ++kr) {
+            const int kqs = kr*WARP_SIZE + threadIdx.x;
+            const int kbxd = kqs / QI8_1;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_x; i0 += nwarps) {
+                const int i = min(blockIdx.y*mmq_x + threadIdx.y + i0, ne11-1); // to prevent out-of-bounds memory accesses
+
+                const block_q8_1 * by0 = &y[i*blocks_per_col_y + kb0 * (qk/QK8_1) + kbxd];
+
+                const int index_y = (i0 + threadIdx.y) * WARP_SIZE + kqs % WARP_SIZE;
+                tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
+            }
+
+#pragma unroll
+            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
+                const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
+                const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
+                const int i_y_eff = min(blockIdx.y*mmq_x + ids, ne11-1);
+
+                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
+                const half2 * dsi_src = &y[i_y_eff*blocks_per_col_y + kb0 * (qk/QK8_1) + kr*(WARP_SIZE/QI8_1) + kby].ds;
+                half2       * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
+                if (need_sum) {
+                    *dsi_dst = *dsi_src;
+                } else {
+                    float * dfi_dst = (float *) dsi_dst;
+                    *dfi_dst = __low2float(*dsi_src);
+                }
+            }
+
+            __syncthreads();
+
+// #pragma unroll // unrolling this loop causes too much register pressure
+            for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
+                vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, sum, k0);
+            }
+
+            __syncthreads();
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
+
+        if (j >= ne1) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
+
+            if (need_check && i >= ne0) {
+                continue;
+            }
+
+            dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+        }
+    }
+}
+
+struct mmq_args {
+    const char * x; const char * y; float * dst;
+    int64_t ne00; int64_t ne01; int64_t stride00;
+    int64_t ne10; int64_t ne11;
+    int64_t ne0;
+};
+
+template <ggml_type type, int mmq_x, int nwarps>
+static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+    const int mmq_y = get_mmq_y_host(cc, mmq_x);
+
+    const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y;
+    const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
+    const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int);
+    const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
+    const int shmem = shmem_x + shmem_y;
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+    if (!shmem_limit_raised[id]) {
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        shmem_limit_raised[id] = true;
+    }
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+
+    if (args.ne01 % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0);
+    } else {
+        const bool need_check = true;
+        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0);
+    }
+}
+
+template <ggml_type type>
+void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
+    const int id = ggml_cuda_get_device();
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    const int cc  = ggml_cuda_info().devices[id].cc;
+
+    const int mmq_x_max = get_mmq_x_max_host(cc);
+    const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
+    const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
+
+    int mmq_x_best  = 0;
+    int nwaves_best = INT_MAX;
+
+    for (int mmq_x = 8; mmq_x <= mmq_x_max && nwaves_best > 1; mmq_x += 8) {
+        const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
+        const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;
+
+        if (nwaves < nwaves_best) {
+            mmq_x_best  = mmq_x;
+            nwaves_best = nwaves;
+        }
+    }
+
+    switch (mmq_x_best) {
+        case   8:
+            launch_mul_mat_q<type,   8, 4>(args, stream);
+            break;
+        case  16:
+            launch_mul_mat_q<type,  16, 8>(args, stream);
+            break;
+        case  24:
+            launch_mul_mat_q<type,  24, 8>(args, stream);
+            break;
+        case  32:
+            launch_mul_mat_q<type,  32, 8>(args, stream);
+            break;
+        case  40:
+            launch_mul_mat_q<type,  40, 8>(args, stream);
+            break;
+        case  48:
+            launch_mul_mat_q<type,  48, 8>(args, stream);
+            break;
+        case  56:
+            launch_mul_mat_q<type,  56, 8>(args, stream);
+            break;
+        case  64:
+            launch_mul_mat_q<type,  64, 8>(args, stream);
+            break;
+        case  72:
+            launch_mul_mat_q<type,  72, 8>(args, stream);
+            break;
+        case  80:
+            launch_mul_mat_q<type,  80, 8>(args, stream);
+            break;
+        case  88:
+            launch_mul_mat_q<type,  88, 8>(args, stream);
+            break;
+        case  96:
+            launch_mul_mat_q<type,  96, 8>(args, stream);
+            break;
+        case 104:
+            launch_mul_mat_q<type, 104, 8>(args, stream);
+            break;
+        case 112:
+            launch_mul_mat_q<type, 112, 8>(args, stream);
+            break;
+        case 120:
+            launch_mul_mat_q<type, 120, 8>(args, stream);
+            break;
+        case 128:
+            launch_mul_mat_q<type, 128, 8>(args, stream);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
+
+#define DECL_MMQ_CASE(type)                                                        \
+    template void mul_mat_q_case<type>(const mmq_args & args, cudaStream_t stream) \
+
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
+extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
+
+// -------------------------------------------------------------------------------------------------------------------------
 
 void ggml_cuda_op_mul_mat_q(
     ggml_backend_cuda_context & ctx,
index d405c124ed657ea78eab84b3f065ed34589bf127..e8d157169544f75cee2d59963c90b5e768a37e0c 100644 (file)
@@ -1,9 +1,47 @@
 #include "mmvq.cuh"
 #include "vecdotq.cuh"
 
-typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
+
+static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 :
+        type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
+        type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
+        type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
+        type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
+        type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
+        type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
+        type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
+        type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
+        type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
+        type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
+        type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
+        type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
+        type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
+        type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
+        type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
+        type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
+        type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
+        type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
+        nullptr;
+}
+
+static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ :
+        type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ :
+        type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ :
+        type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ :
+        type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ :
+        type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ4_NL ? VDR_Q4_K_Q8_1_MMVQ :
+        1;
+}
 
-template <int ncols_y, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+template <ggml_type type, int ncols_y>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
 __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
@@ -12,6 +50,12 @@ static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
+    constexpr int qk  = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi  = ggml_cuda_type_traits<type>::qi;
+    constexpr int vdr = get_vdr_mmvq(type);
+
+    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
+
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
     constexpr int nwarps              = 1;
     constexpr int rows_per_cuda_block = 1;
@@ -29,7 +73,6 @@ static __global__ void mul_mat_vec_q(
 // partial sum for each thread
     float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
 
-    const block_q_t  * x = (const block_q_t  *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
     for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
@@ -42,8 +85,7 @@ static __global__ void mul_mat_vec_q(
         for (int j = 0; j < ncols_y; ++j) {
 #pragma unroll
             for (int i = 0; i < rows_per_cuda_block; ++i) {
-                tmp[j][i] += vec_dot_q_cuda(
-                    &x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
+                tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
             }
         }
     }
@@ -81,12 +123,12 @@ static __global__ void mul_mat_vec_q(
     }
 }
 
-template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
+template <ggml_type type>
 static void mul_mat_vec_q_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    GGML_ASSERT(ncols_x % qk == 0);
+    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 
     int id = ggml_cuda_get_device();
@@ -124,36 +166,28 @@ static void mul_mat_vec_q_cuda(
 
     switch (ncols_y) {
         case 1:
-            mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         case 2:
-            mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         case 3:
-            mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         case 4:
-            mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         case 5:
-            mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         case 6:
-            mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         case 7:
-            mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         case 8:
-            mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         default:
             GGML_ASSERT(false);
@@ -165,152 +199,133 @@ static void mul_mat_vec_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q4_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q5_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q5_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q8_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q2_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q3_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q4_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q5_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q6_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq2_xxs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq2_xs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq2_s_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq3_xxs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq1_s_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq1_m_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq4_nl_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq4_xs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq3_s_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 void ggml_cuda_op_mul_mat_vec_q(
index d7f1034751cc1d9c9b7c9237da72acad82875cd7..6696a238476d8d9cd3aa4438928ad8fd67a66a35 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index f3d8d2eda49912e4a5673603e0bfa72078277ad9..dd070db2853f549fd63461fe1a1306ea0be6b7da 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 9beb05ca23597262263f716262013d98ad963b49..54dcde6f523247ed50c2fccb4d2aabfe252b26e9 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 0c163dcba06139254391c797b290821e57d34acd..4ec22f791912d13f5e3904dd577b3ed89961d20e 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 3980167b3db919333aab2e9089018357cbdffc6b..3c15bf7f0ef16054c2965bf4bef3e9e5fd8c1313 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index fe099921d08db37d1eda50c794a9bdd5c4a25258..7e61b5fdcdbca4f21a3d24a28c58e30c7aa8fcc8 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index d4d5e7999e3939c0903b99e18a753ce1f9f4b15c..fdb15b580cff8887f6fbc555883fc3c9543cdfde 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index f08b10c4d53c13f5ca137993f02ad7c6897a3d2d..0f7c417d2c0c8faa9f1ffcda4c88e78771961389 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index e8c3f8adc2a1e03fd542cbc791b3d07e2cd02245..851f33c43f0407869caa54952ca3bc34763f1cd4 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index c01416a13f6ac888e55119d1107933b80d3304ab..763809cbeb44c6aba251240c41a45833aa5d5416 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 46615f281cb18af1b904ae54a5539e84d6ecba3d..f2a276e50e5fa6e3ee1acd3cdff10fc027146726 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 72dcc1a2fe5636fecd255d26d61379317b81cca6..cb227f6f5ce1f47b62726bfed6ee55e121f35101 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 9fa8a377d0c6f4e80795f61ae31a71cf17460876..97ac0520c71d16413b07d0ae25854662d0ddc0e4 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 20ea86c6dcb89bde2cd0baf7bdeff16d47700d6b..c772b42634fe6b3d3c26d6c88e55cd33906bb017 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index ed815957cb212725652396e9e8344d6dfa3bb17b..5cb7430819e4e50f8a387b731d20db290c4f3142 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index bbe9e6a1c8e613348c8fb568c0c18d1cb98a2a1f..98a709d171446af665703a701b18859a9a95e9c4 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index d12a616996a9b805c64fb9ef51cd1a4c3088546b..4f2f947ae81e66989283b4db6651cd0232ef5aec 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 1e901afcbbc996235488535595a1379eccbcfcf8..11f96b6f65ceedcc93a112bcd5f0831594bdf9ed 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index a3f98ce37456c9bd421d3fb9368ba41f7f774678..b39bdc0611c0d734a6d2ff7b669b92d81f76a719 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 1bae97243a062998db4180343d8b10f58810f2fd..bbd6a2c7f491cbda161dbbdd9ade87461dff5630 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 7258e9775da099765d0631c5c2b5755ddbfa8dbf..9d84ff2b1917510a6e5ece7a645ac14e4b160572 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 08435c005e81cf0bc30c48f37665601ddd66ebab..bc8a5bff684ff5d162559bdc4106e1fc969991ad 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 17864e8e9cdf30549f433f85444957e3884efad1..a679100c8380727903c4e91f0bdd74e666686d47 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 9239138c962ae845b41072f9757c091a2fb1abc9..8f21bccf7f8da823bd44d20d1e36cad7339a727e 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index e387d9c1da3e2a97318606af5bb2a369098aa8f1..858b00fd741918d43c981da65d278f8851632883 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index d69d3bbd683aa232aaf246e349e28440dba877aa..0fc8011fac5fc3ce61b3d1f11c9f7975cc4d636a 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 61a4788166ed7b11b4fb42adfdf3bba2576d5abf..261fdf623e098b6858089b4f94985c330599caba 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 89995080ac5f7d07609adc782ad9e5ee04dda64c..0fb8247383063443f250060a427e0efdf1078930 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 9e6a58dff675953971aedd896dc528048f885a2a..a9d9d089bd314e5a15dbede60bc4cde319454776 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 153cbfd86c82d7ead2a2cb669d0d7917a8af1d6d..7d7b27920aa3e790ac09d8aad804f3a4ee8dd590 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 09d5765582c8f8854846e7aae79a8fdba9c088b5..a092ee2d5095786a64b58ec84443b6d6231d3216 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 3e3c91e68c2196a866a9e3f45e8ba6d070a59128..db55927a19457ee663177e6f54f6d39043a92355 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 7b973058f4a891441f25d8d715cf93ed4a2d0454..c3c21cefae047ebd2f3868f95ede5308bafd1fa0 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index a43a475d44439241f62ffc0b6009521a7cd17182..35dd9f520802cb1fbb6b1e9192aa834a2c9c3ef6 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 5b570c0a3df41ec124e88ed248782ed39b1cebe0..050c22ac7c6c7a2c4586ab8906e1464c8fa047c3 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index bf2cc684ef459769821030f61da5b1a3b954e706..de4866c5e65cef37096ff390b755c99ae40efb8c 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 7428e45ea2621769424e1f2e0eeb9dac0d96948e..57a10bc4be4a3536320cd95e077d11a92697cf65 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 4aee830de1228851c436aafe48cc0d7cf1076f1f..e0f08b46a7e353952edec92292861fe6d7ad4e94 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 36acb631915183f7b5c79107de63c8596ed02022..1c8e8a467a8aa82bc66692df54c1bb8d039ad3c3 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index a4090c39055070bd4453ae181070e486a9664e6e..cefed83fb9562fcbc32430c1992c99d89e13d03b 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 17b6b2d117b52b5e9d8e54a587807d20478161d8..aede6e358819586781ee8eb56862f0b3eb41b54b 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 549e1cea1f379a9354d93d8e716704f0cc66eebb..1a1a92c788fbd9fdb71f67110446ca6902da6eac 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 66bcd820f9a203a80d9ff7dac1816f04dc1a5f35..ad667473d110b8ee7c6f9ad12b8d5b521b8befef 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 
index 15933a29977d077bc87f6cdcface30f27837e43c..c499f455da97186534177715e7692c96a120ea63 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 8aa78558395167425319f0c6acb6ed82e83c9ccd..8286ebf373627341e693f2d384d1073551937e69 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index bde3924fd7a28782673f88ea1b5e66a96b800d5a..4587868825d21f8ff6a2d9af495e71ad159faad6 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 1708181c15874db7705bdbfa4ddac6c8544dd464..d89103ce0c68fa6ab2c9daa9c8093d7cc69d4e9f 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 30fa6fa4cebf0c4a465b93acc45df4c9574f5055..bb75fd42ff17d7135510efccff3e4069a66034a5 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 69673d50f4b3ca47cff1d3c00f1838bd6b5b242d..b1629817e79e39f110ffb3c4c363d76efe058893 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index d8b2b2e18496e85a58f12c9ce8b9b698aea955d2..d8657604dab8035e2e9176bb985247bd26c2762a 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 01cce7ab50a7a03c8b5c085d15a9f8ff4087bb3f..2e5bd2f1a3accd948996c15568cf763698d1cd1a 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index fd5563b395a5f317bb9931f9d726a820a7996d45..be5f302d9f1d41d4ae8d35b3eb331fb6b1bffa1f 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index b13cc4a0c7e8d48f8031abd9f6b1958135fba78b..8dd91cd72eb608fa7077cfe34987d5dd58545286 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 86f1fc63701f32261ce47849e4e283827db5e821..4cb791502a157a9785d38e3169c9b8971ddbac7e 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 26e7df4be61f16ff3fc4158d961922ffd4006d08..09dea426736e930394961b99c8417f2a26fbc0b1 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index e4fda8952c75fefeb70f3ab4fc13823656058289..0fbb607694f2543149d6ab738cf57a4d790e1b9b 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index bd15117b40db702d16d96684cfda2e1413fb2dc1..2aeab83b20d21b1d1b7349c4136863dbea4abf43 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index cb6c6a7606aace4ddcfd76ae0c6fc67a8d005ac9..599415b4947411d8bad1f15899522c42b1efac9d 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 201b6641de102840b6807f5d9ff3af768847625c..e4f8e3083bb6bec6b78b0dec443e174b8a823d88 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 6da57a44aea08e9ff4eeb1f026f0a8d04805ccac..34d166527e93ae9f8ffa61e28877976431d2b8d4 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 47623c9bff42c71b3aff0b7c65f99e713e29ca81..4bebef45a37cb2850cc2301baaf6bfbd2c381dc2 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 82c6861d24f7a5545967fa3e6ab1550c30faf184..326468da2fb240baa2aed69c49fda4c9a137dba4 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 24a80c2b042f57f890a74bb81325d06ba589a159..511b58f4ecc725c3ed7a54b7b8b3943f4f28f20f 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index b95eaf7e18ccb90ba219f8ff69ec8fd140c8750d..d9906d142e159a3e179781b1e8b0c1c786ba1aef 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 275f2efccd18af973eaa176764e78313c5f4affa..f61c183abbaf79e350ff7f92c343f2a9655dd427 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 3673f7fd55316ad04fb47155babe8da0c1a04af5..c10450fd29e7691e92aad6e217efd32768e36ced 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 2c4d59947b1eb29b5757daafb7661cc9eb45c6dd..2d5cb195c41dce04ead664bb17b2ffaf5b57a94a 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 2457cdf3fe77db8af413b87f58f1ce6b6005599d..b384f34d7d9214a4c7db2d3993bfe8a57dfd87b9 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index b3b411ed368b8a142b1a3c950da1bb977caf5f9e..446e293b16edcf38525db2760d0cbdf340a59cfb 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index b7f308a4d0b7ed5851e596fb948a3b8d167df0d6..6f430298899c7d469a80026a6aabb3b81814d83f 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 7396866975b97e0b156d357add07bce4fcbe98ed..1cd8ba88fd6503fab0bfe2da050e9971b0ff1bdb 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 708d03113e5005c7387a6f7eacb840da7dfa329f..1ee2eab65a1c962e162b43d2158576bd06f42784 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index df891be6031ce7249af189517d07744ca9ee594c..2bc77816a5d4e4596bce991cc0f883732d7302ee 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index f49b6d1f9ef8f0a33ee348fa72b1f57e48c487c2..d55ced08bc940e9953eced46b09e9f4e6fd9b0f8 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 1de92148b714f7f03fb7951eba909d3f2e69875c..8361e99c4e4a4636f50df6b436ebc9ed5680d0ab 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 7a1ba7f8de4b76f3826e1b47a9f7bdb59f8e7bfc..7507a67c4c5e9061fa2909cfebc7d8c7d1ec4db0 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 25493e4bac9ee103523168e30f6db3c8e9756edb..61f050b235ff222d533e3ff7600070d6c39bfd0b 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 3cd650c7bc3e4874c365ceee2962c213eb38c337..d4a49d9c9912a37128e2bd07021d4a651f970112 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 88ffa43d6f1617da9ae3021139358f188999a03f..d14627897621185aac64c0f98364a57d30b32631 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 8c7bac6c2d7d9babd06c4bac63117f27cdbe6029..e73f917a1f1867590a67eb27410c76bfb44c169f 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index a28f62e7b82025f5322024b2362b4dfbcb95f2c3..d40825dfc21f035c3cf10626aeb70926c69374d6 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index d39838b96fa9d7a59e07b8b7daeec66559467fc6..b5c6869f4ec42b685ae8a39fb5b2d865608100e4 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 834d40f6c83929ba42ca0341a00ed2a7486f9372..4e21b0ccaef1622d98057ae2d727302f530ff4bf 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index f7d54668b9d86aa278de36fa223fff6f3c32abf6..2eac321b370dfd54fefd322cfe6a3f6770e2ce11 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 59e00ad83414cdb59e31222a0786a719654ffca1..f7d2c3b4e0a12cfd8cdf4954e2040c4eeb5088d3 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index 6e63893debd6be367f7970b4a435fc3094174f60..a013f400bd33bfd40eb975d6e3d28f80a3fe9473 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f32.cuh"
 
index ca356ad6c9ca0452fa53456744c28d4a3c7045a4..2d94e65c28c298e512e99f2783f491560b81c591 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-wmma-f16.cuh"
 
index 430ee64eb29a59a1cb440dcd12eb9d4ac754db14..c3d9df3c443137185ba1e797a5b6079d025044c1 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-wmma-f16.cuh"
 
index d421d17ccc5fd0436c9c979f589edd08dcf409e1..bb680e401f7da5f9013c6f4aca1f452ec40c85b5 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-wmma-f16.cuh"
 
index deacd5f58eec9a02ff2f2187fc815a0a3429003d..073f71b1f3e267b5d0d1358215abc593df134cc9 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-wmma-f16.cuh"
 
index 282896733473f773b7931b5c0ceb5f10b6fbe705..d30710c5fa21b66ce341ff76921c93c0efded187 100644 (file)
@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-wmma-f16.cuh"
 
index ee5b460e07986c3355948986e1b9483470fc7c84..ea58d09680231f3ce1482452492af9fdb5567c6d 100755 (executable)
@@ -20,6 +20,18 @@ SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_
 
 SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n"
 
+TYPES_MMQ = [
+    "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
+    "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K"
+]
+
+SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE({type});
+"""
+
 
 def get_short_name(long_quant_name):
     return long_quant_name.replace("GGML_TYPE_", "").lower()
@@ -57,3 +69,7 @@ for kq_acc_t in ["half", "float"]:
                 if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance
                     continue
                 f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size))
+
+for type in TYPES_MMQ:
+    with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
+        f.write(SOURCE_MMQ.format(type=type))
diff --git a/ggml-cuda/template-instances/mmq-instance-q2_k.cu b/ggml-cuda/template-instances/mmq-instance-q2_k.cu
new file mode 100644 (file)
index 0000000..6415369
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q2_K);
diff --git a/ggml-cuda/template-instances/mmq-instance-q3_k.cu b/ggml-cuda/template-instances/mmq-instance-q3_k.cu
new file mode 100644 (file)
index 0000000..ffb6213
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q3_K);
diff --git a/ggml-cuda/template-instances/mmq-instance-q4_0.cu b/ggml-cuda/template-instances/mmq-instance-q4_0.cu
new file mode 100644 (file)
index 0000000..0c0b0c8
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q4_0);
diff --git a/ggml-cuda/template-instances/mmq-instance-q4_1.cu b/ggml-cuda/template-instances/mmq-instance-q4_1.cu
new file mode 100644 (file)
index 0000000..ee67f69
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q4_1);
diff --git a/ggml-cuda/template-instances/mmq-instance-q4_k.cu b/ggml-cuda/template-instances/mmq-instance-q4_k.cu
new file mode 100644 (file)
index 0000000..9eeb3cd
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q4_K);
diff --git a/ggml-cuda/template-instances/mmq-instance-q5_0.cu b/ggml-cuda/template-instances/mmq-instance-q5_0.cu
new file mode 100644 (file)
index 0000000..cc57fb9
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q5_0);
diff --git a/ggml-cuda/template-instances/mmq-instance-q5_1.cu b/ggml-cuda/template-instances/mmq-instance-q5_1.cu
new file mode 100644 (file)
index 0000000..721ac79
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q5_1);
diff --git a/ggml-cuda/template-instances/mmq-instance-q5_k.cu b/ggml-cuda/template-instances/mmq-instance-q5_k.cu
new file mode 100644 (file)
index 0000000..a2e90ff
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q5_K);
diff --git a/ggml-cuda/template-instances/mmq-instance-q6_k.cu b/ggml-cuda/template-instances/mmq-instance-q6_k.cu
new file mode 100644 (file)
index 0000000..470938f
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q6_K);
diff --git a/ggml-cuda/template-instances/mmq-instance-q8_0.cu b/ggml-cuda/template-instances/mmq-instance-q8_0.cu
new file mode 100644 (file)
index 0000000..974477b
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q8_0);
index df9752390509dd3d4caa3ba17f49678aafcd5785..b9573a7c7d053d4e5707d091f2898bc414bd937d 100644 (file)
@@ -566,9 +566,9 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
 }
 
 static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
+    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq + kbx;
 
     int v[VDR_Q4_0_Q8_1_MMVQ];
     int u[2*VDR_Q4_0_Q8_1_MMVQ];
@@ -585,9 +585,9 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
 
 
 static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq + kbx;
 
     int v[VDR_Q4_1_Q8_1_MMVQ];
     int u[2*VDR_Q4_1_Q8_1_MMVQ];
@@ -603,9 +603,9 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
+    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq + kbx;
 
     int vl[VDR_Q5_0_Q8_1_MMVQ];
     int vh[VDR_Q5_0_Q8_1_MMVQ];
@@ -623,9 +623,9 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
+    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq + kbx;
 
     int vl[VDR_Q5_1_Q8_1_MMVQ];
     int vh[VDR_Q5_1_Q8_1_MMVQ];
@@ -643,9 +643,9 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
+    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq + kbx;
 
     int v[VDR_Q8_0_Q8_1_MMVQ];
     int u[VDR_Q8_0_Q8_1_MMVQ];
@@ -660,9 +660,9 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q2_K * bq2_K = (const block_q2_K *) vbq;
+    const block_q2_K * bq2_K = (const block_q2_K *) vbq + kbx;
 
     const int bq8_offset = QR2_K * (iqs / QI8_1);
     const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
@@ -683,9 +683,9 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q3_K * bq3_K = (const block_q3_K *) vbq;
+    const block_q3_K * bq3_K = (const block_q3_K *) vbq + kbx;
 
     const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
     const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
@@ -710,9 +710,9 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+    const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx;
 
     int    v[2];
     int    u[2*QR4_K];
@@ -756,9 +756,9 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+    const block_q5_K * bq5_K = (const block_q5_K *) vbq + kbx;
 
     int   vl[2];
     int   vh[2];
@@ -802,9 +802,9 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q6_K * bq6_K = (const block_q6_K *) vbq;
+    const block_q6_K * bq6_K = (const block_q6_K *) vbq + kbx;
 
     const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
     const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
@@ -828,8 +828,8 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
-    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq + kbx;
 
 #if QR2_XXS == 8
     const int ib32 = iqs;
@@ -872,9 +872,9 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
+    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq + kbx;
 
     const int ib32 = iqs;
     const uint16_t * q2 = bq2->qs + 4*ib32;
@@ -911,9 +911,9 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
 
 // TODO
 static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
+    const block_iq2_s * bq2 = (const block_iq2_s *) vbq + kbx;
 
     const int ib32 = iqs;
     const int8_t  * q8 = bq8_1[ib32].qs;
@@ -951,9 +951,9 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;
+    const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq + kbx;
 
     const int ib32 = iqs;
     const uint8_t  * q3 = bq2->qs + 8*ib32;
@@ -981,9 +981,9 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
 
 // TODO: don't use lookup table for signs
 static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
+    const block_iq3_s * bq2 = (const block_iq3_s *) vbq + kbx;
 
     const int ib32 = iqs;
     const uint8_t  * qs = bq2->qs + 8*ib32;
@@ -1008,8 +1008,8 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
-    const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+    const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx;
 
     const int ib32 = iqs;
     int sumi = 0;
@@ -1039,8 +1039,8 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
-    const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+    const block_iq1_m * bq1 = (const block_iq1_m *) vbq + kbx;
 
     const int ib32 = iqs;
     int   sumi[2] = {0, 0};
@@ -1094,9 +1094,9 @@ static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4
 #endif
 
 static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
+    const block_iq4_nl * bq = (const block_iq4_nl *) vbq + kbx;
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
@@ -1128,10 +1128,10 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
+    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq + kbx;
     const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
 
     // iqs is 0...7
@@ -1149,6 +1149,6 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
     }
     return d * (sumi1 + sumi2);
 #else
-    return vec_dot_iq4_xs_q8_1(vbq, bq8_1, iqs);
+    return vec_dot_iq4_xs_q8_1(vbq, bq8_1, kbx, iqs);
 #endif
 }