]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: Fixed OpenLLaMA 3b mmq, reduced compile time (#2590)
authorJohannes Gäßler <redacted>
Sat, 12 Aug 2023 22:24:45 +0000 (00:24 +0200)
committerGitHub <redacted>
Sat, 12 Aug 2023 22:24:45 +0000 (00:24 +0200)
CMakeLists.txt
ggml-cuda.cu

index d085bc835cf38b886d9a2f36ca717e1c25232e3f..dff4942cdb9b2703cc5d7e145f538806c6244dc1 100644 (file)
@@ -69,7 +69,6 @@ option(LLAMA_BLAS                            "llama: use BLAS"
 set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
 option(LLAMA_CUBLAS                          "llama: use CUDA"                                  OFF)
 #option(LLAMA_CUDA_CUBLAS                     "llama: use cuBLAS for prompt processing"          OFF)
-set(LLAMA_CUDA_MMQ_Y       "64" CACHE STRING "llama: y tile size for mmq CUDA kernels")
 option(LLAMA_CUDA_FORCE_DMMV                 "llama: use dmmv instead of mmvq CUDA kernels"     OFF)
 set(LLAMA_CUDA_DMMV_X      "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
 set(LLAMA_CUDA_MMV_Y        "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
@@ -256,7 +255,6 @@ if (LLAMA_CUBLAS)
 #        if (LLAMA_CUDA_CUBLAS)
 #            add_compile_definitions(GGML_CUDA_CUBLAS)
 #        endif()
-        add_compile_definitions(GGML_CUDA_MMQ_Y=${LLAMA_CUDA_MMQ_Y})
         if (LLAMA_CUDA_FORCE_DMMV)
             add_compile_definitions(GGML_CUDA_FORCE_DMMV)
         endif()
index 6390b1158b6a6c5b792ff8bd9f1731c7c499596f..11f67aec8e47bc48986822fb770d81c348f3e2ee 100644 (file)
@@ -1399,6 +1399,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
     // second part effectively subtracts 8 from each quant value
     return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
@@ -1436,6 +1437,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
     // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
     return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
@@ -1471,6 +1473,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
     // second part effectively subtracts 16 from each quant value
     return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
@@ -1516,6 +1519,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
     return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
 
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
@@ -1537,6 +1541,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp
 
     return d8_0*d8_1 * sumi;
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
@@ -1567,6 +1572,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
     // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
     return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
@@ -1602,6 +1608,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
 
     return dm2f.x*sumf_d - dm2f.y*sumf_m;
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
@@ -1639,6 +1646,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
 
     return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
@@ -1679,6 +1687,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
 
     return d3 * sumf;
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
@@ -1704,6 +1713,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
 
     return d3*d8 * sumi;
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
@@ -1737,6 +1747,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
     return dm4f.x*sumf_d - dm4f.y*sumf_m;
 
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
@@ -1772,6 +1783,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
     return dm4f.x*sumf_d - dm4f.y*sumf_m;
 
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
@@ -1812,6 +1824,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl(
     return dm5f.x*sumf_d - dm5f.y*sumf_m;
 
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
@@ -1842,6 +1855,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
 
     return d*sumf;
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
@@ -1873,6 +1887,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
     return d6 * sumf_d;
 
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
@@ -2722,6 +2737,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     return dall * sumf_d - dmin * sumf_m;
 
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 
@@ -2905,6 +2921,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
     return d * sumf_d;
 
 #else
+    assert(false);
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 
@@ -3135,7 +3152,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
 
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
               allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
-static __global__ void mul_mat_q(
+static __device__ __forceinline__ void mul_mat_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
@@ -3150,7 +3167,6 @@ static __global__ void mul_mat_q(
 
     const int row_dst_0 = blockIdx.x*mmq_y;
     const int & row_x_0 = row_dst_0;
-    const int row_dst = row_dst_0 + threadIdx.x;
 
     const int col_dst_0 = blockIdx.y*mmq_x;
     const int & col_y_0 = col_dst_0;
@@ -3223,11 +3239,7 @@ static __global__ void mul_mat_q(
         }
     }
 
-
-    if (row_dst >= nrows_dst) {
-        return;
-    }
-
+#pragma unroll
     for (int j = 0; j < mmq_x; j += nwarps) {
         const int col_dst = col_dst_0 + j + threadIdx.y;
 
@@ -3235,12 +3247,359 @@ static __global__ void mul_mat_q(
             return;
         }
 
+#pragma unroll
         for (int i = 0; i < mmq_y; i += WARP_SIZE) {
-            dst[col_dst*nrows_dst + row_dst + i] = sum[i/WARP_SIZE][j/nwarps];
+            const int row_dst = row_dst_0 + threadIdx.x + i;
+
+            if (row_dst >= nrows_dst) {
+                continue;
+            }
+
+            dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];
         }
     }
 }
 
+#define  MMQ_X_Q4_0_AMPERE 64
+#define  MMQ_Y_Q4_0_AMPERE 128
+#define NWARPS_Q4_0_AMPERE 4
+#define  MMQ_X_Q4_0_PASCAL 64
+#define  MMQ_Y_Q4_0_PASCAL 64
+#define NWARPS_Q4_0_PASCAL 8
+
+template <bool need_check> static __global__ void mul_mat_q4_0(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+#if __CUDA_ARCH__ >= CC_TURING
+    const int mmq_x  =  MMQ_X_Q4_0_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q4_0_AMPERE;
+    const int nwarps = NWARPS_Q4_0_AMPERE;
+
+    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+        load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int mmq_x  =  MMQ_X_Q4_0_PASCAL;
+    const int mmq_y  =  MMQ_Y_Q4_0_PASCAL;
+    const int nwarps = NWARPS_Q4_0_PASCAL;
+
+    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+        load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+#else
+    (void) vec_dot_q4_0_q8_1_mul_mat;
+    assert(false);
+#endif // __CUDA_ARCH__ >= CC_TURING
+}
+
+#define  MMQ_X_Q4_1_AMPERE 64
+#define  MMQ_Y_Q4_1_AMPERE 128
+#define NWARPS_Q4_1_AMPERE 4
+#define  MMQ_X_Q4_1_PASCAL 64
+#define  MMQ_Y_Q4_1_PASCAL 64
+#define NWARPS_Q4_1_PASCAL 8
+
+template <bool need_check> static __global__ void mul_mat_q4_1(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+#if __CUDA_ARCH__ >= CC_TURING
+    const int mmq_x  =  MMQ_X_Q4_1_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q4_1_AMPERE;
+    const int nwarps = NWARPS_Q4_1_AMPERE;
+
+    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
+        load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int mmq_x  =  MMQ_X_Q4_1_PASCAL;
+    const int mmq_y  =  MMQ_Y_Q4_1_PASCAL;
+    const int nwarps = NWARPS_Q4_1_PASCAL;
+
+    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
+        load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+#else
+    (void) vec_dot_q4_1_q8_1_mul_mat;
+    assert(false);
+#endif // __CUDA_ARCH__ >= CC_TURING
+}
+
+#define  MMQ_X_Q5_0_AMPERE 128
+#define  MMQ_Y_Q5_0_AMPERE 64
+#define NWARPS_Q5_0_AMPERE 4
+#define  MMQ_X_Q5_0_PASCAL 64
+#define  MMQ_Y_Q5_0_PASCAL 64
+#define NWARPS_Q5_0_PASCAL 8
+
+template <bool need_check> static __global__ void mul_mat_q5_0(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+#if __CUDA_ARCH__ >= CC_TURING
+    const int mmq_x  =  MMQ_X_Q5_0_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q5_0_AMPERE;
+    const int nwarps = NWARPS_Q5_0_AMPERE;
+
+    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
+        load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int mmq_x  =  MMQ_X_Q5_0_PASCAL;
+    const int mmq_y  =  MMQ_Y_Q5_0_PASCAL;
+    const int nwarps = NWARPS_Q5_0_PASCAL;
+
+    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
+        load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+#else
+    (void) vec_dot_q5_0_q8_1_mul_mat;
+    assert(false);
+#endif // __CUDA_ARCH__ >= CC_TURING
+}
+
+#define  MMQ_X_Q5_1_AMPERE 128
+#define  MMQ_Y_Q5_1_AMPERE 64
+#define NWARPS_Q5_1_AMPERE 4
+#define  MMQ_X_Q5_1_PASCAL 64
+#define  MMQ_Y_Q5_1_PASCAL 64
+#define NWARPS_Q5_1_PASCAL 8
+
+template <bool need_check> static __global__ void mul_mat_q5_1(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+#if __CUDA_ARCH__ >= CC_TURING
+    const int mmq_x  =  MMQ_X_Q5_1_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q5_1_AMPERE;
+    const int nwarps = NWARPS_Q5_1_AMPERE;
+
+    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
+        load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int mmq_x  =  MMQ_X_Q5_1_PASCAL;
+    const int mmq_y  =  MMQ_Y_Q5_1_PASCAL;
+    const int nwarps = NWARPS_Q5_1_PASCAL;
+
+    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
+        load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+#else
+    (void) vec_dot_q5_1_q8_1_mul_mat;
+    assert(false);
+#endif // __CUDA_ARCH__ >= CC_TURING
+}
+
+#define  MMQ_X_Q8_0_AMPERE 128
+#define  MMQ_Y_Q8_0_AMPERE 64
+#define NWARPS_Q8_0_AMPERE 4
+#define  MMQ_X_Q8_0_PASCAL 64
+#define  MMQ_Y_Q8_0_PASCAL 64
+#define NWARPS_Q8_0_PASCAL 8
+
+template <bool need_check> static __global__ void mul_mat_q8_0(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+#if __CUDA_ARCH__ >= CC_TURING
+    const int mmq_x  =  MMQ_X_Q8_0_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q8_0_AMPERE;
+    const int nwarps = NWARPS_Q8_0_AMPERE;
+
+    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
+        load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int mmq_x  =  MMQ_X_Q8_0_PASCAL;
+    const int mmq_y  =  MMQ_Y_Q8_0_PASCAL;
+    const int nwarps = NWARPS_Q8_0_PASCAL;
+
+    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
+        load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+#else
+    (void) vec_dot_q8_0_q8_1_mul_mat;
+    assert(false);
+#endif // __CUDA_ARCH__ >= CC_TURING
+}
+
+#define  MMQ_X_Q2_K_AMPERE 64
+#define  MMQ_Y_Q2_K_AMPERE 128
+#define NWARPS_Q2_K_AMPERE 4
+#define  MMQ_X_Q2_K_PASCAL 64
+#define  MMQ_Y_Q2_K_PASCAL 64
+#define NWARPS_Q2_K_PASCAL 8
+
+template <bool need_check> static __global__ void mul_mat_q2_K(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+#if __CUDA_ARCH__ >= CC_TURING
+    const int mmq_x  =  MMQ_X_Q2_K_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q2_K_AMPERE;
+    const int nwarps = NWARPS_Q2_K_AMPERE;
+
+    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
+        load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int mmq_x  =  MMQ_X_Q2_K_PASCAL;
+    const int mmq_y  =  MMQ_Y_Q2_K_PASCAL;
+    const int nwarps = NWARPS_Q2_K_PASCAL;
+
+    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
+        load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+#else
+    (void) vec_dot_q2_K_q8_1_mul_mat;
+    assert(false);
+#endif // __CUDA_ARCH__ >= CC_TURING
+}
+
+#define  MMQ_X_Q3_K_AMPERE 128
+#define  MMQ_Y_Q3_K_AMPERE 128
+#define NWARPS_Q3_K_AMPERE 4
+#define  MMQ_X_Q3_K_PASCAL 64
+#define  MMQ_Y_Q3_K_PASCAL 64
+#define NWARPS_Q3_K_PASCAL 8
+
+template <bool need_check> static __global__ void mul_mat_q3_K(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+#if __CUDA_ARCH__ >= CC_TURING
+    const int mmq_x  =  MMQ_X_Q3_K_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q3_K_AMPERE;
+    const int nwarps = NWARPS_Q3_K_AMPERE;
+
+    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
+        load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int mmq_x  =  MMQ_X_Q3_K_PASCAL;
+    const int mmq_y  =  MMQ_Y_Q3_K_PASCAL;
+    const int nwarps = NWARPS_Q3_K_PASCAL;
+
+    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
+        load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+#else
+    (void) vec_dot_q3_K_q8_1_mul_mat;
+    assert(false);
+#endif // __CUDA_ARCH__ >= CC_TURING
+}
+
+#define  MMQ_X_Q4_K_AMPERE 64
+#define  MMQ_Y_Q4_K_AMPERE 128
+#define NWARPS_Q4_K_AMPERE 4
+#define  MMQ_X_Q4_K_PASCAL 32
+#define  MMQ_Y_Q4_K_PASCAL 64
+#define NWARPS_Q4_K_PASCAL 8
+
+template <bool need_check> static __global__ void mul_mat_q4_K(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+#if __CUDA_ARCH__ >= CC_TURING
+    const int mmq_x  =  MMQ_X_Q4_K_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q4_K_AMPERE;
+    const int nwarps = NWARPS_Q4_K_AMPERE;
+
+    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
+        load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int mmq_x  =  MMQ_X_Q4_K_PASCAL;
+    const int mmq_y  =  MMQ_Y_Q4_K_PASCAL;
+    const int nwarps = NWARPS_Q4_K_PASCAL;
+
+    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
+        load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+#else
+    (void) vec_dot_q4_K_q8_1_mul_mat;
+    assert(false);
+#endif // __CUDA_ARCH__ >= CC_TURING
+}
+
+#define  MMQ_X_Q5_K_AMPERE 64
+#define  MMQ_Y_Q5_K_AMPERE 128
+#define NWARPS_Q5_K_AMPERE 4
+#define  MMQ_X_Q5_K_PASCAL 64
+#define  MMQ_Y_Q5_K_PASCAL 64
+#define NWARPS_Q5_K_PASCAL 8
+
+template <bool need_check> static __global__ void mul_mat_q5_K(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+#if __CUDA_ARCH__ >= CC_TURING
+    const int mmq_x  =  MMQ_X_Q5_K_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q5_K_AMPERE;
+    const int nwarps = NWARPS_Q5_K_AMPERE;
+
+    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
+        load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int mmq_x  =  MMQ_X_Q5_K_PASCAL;
+    const int mmq_y  =  MMQ_Y_Q5_K_PASCAL;
+    const int nwarps = NWARPS_Q5_K_PASCAL;
+
+    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
+        load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+#else
+    (void) vec_dot_q5_K_q8_1_mul_mat;
+    assert(false);
+#endif // __CUDA_ARCH__ >= CC_TURING
+}
+
+#define  MMQ_X_Q6_K_AMPERE 64
+#define  MMQ_Y_Q6_K_AMPERE 64
+#define NWARPS_Q6_K_AMPERE 4
+#define  MMQ_X_Q6_K_PASCAL 32
+#define  MMQ_Y_Q6_K_PASCAL 64
+#define NWARPS_Q6_K_PASCAL 8
+
+template <bool need_check> static __global__ void mul_mat_q6_K(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+#if __CUDA_ARCH__ >= CC_TURING
+    const int mmq_x  =  MMQ_X_Q6_K_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q6_K_AMPERE;
+    const int nwarps = NWARPS_Q6_K_AMPERE;
+
+    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
+        load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+
+#elif __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int mmq_x  =  MMQ_X_Q6_K_PASCAL;
+    const int mmq_y  =  MMQ_Y_Q6_K_PASCAL;
+    const int nwarps = NWARPS_Q6_K_PASCAL;
+
+    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
+        load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+#else
+    (void) vec_dot_q6_K_q8_1_mul_mat;
+    assert(false);
+#endif // __CUDA_ARCH__ >= CC_TURING
+}
+
 template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
 static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
@@ -3942,48 +4301,32 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
     CUDA_CHECK(cudaGetDevice(&id));
     const int compute_capability = g_compute_capabilities[id];
 
+    int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_TURING) {
-        const int mmq_x  = 64;
-        const int mmq_y  = 128;
-        const int nwarps = 4;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
-                load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
-                load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        mmq_x  =  MMQ_X_Q4_0_AMPERE;
+        mmq_y  =  MMQ_Y_Q4_0_AMPERE;
+        nwarps = NWARPS_Q4_0_AMPERE;
+    } else if (compute_capability >= MIN_CC_DP4A) {
+        mmq_x  =  MMQ_X_Q4_0_PASCAL;
+        mmq_y  =  MMQ_Y_Q4_0_PASCAL;
+        nwarps = NWARPS_Q4_0_PASCAL;
     } else {
-        const int mmq_x  = 64;
-        const int mmq_y  = 64;
-        const int nwarps = 4;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
-                load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
-                load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
 
@@ -3995,49 +4338,32 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
     CUDA_CHECK(cudaGetDevice(&id));
     const int compute_capability = g_compute_capabilities[id];
 
+    int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_TURING) {
-        const int mmq_x  = 64;
-        const int mmq_y  = 128;
-        const int nwarps = 4;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
-                load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
-                load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        mmq_x  =  MMQ_X_Q4_1_AMPERE;
+        mmq_y  =  MMQ_Y_Q4_1_AMPERE;
+        nwarps = NWARPS_Q4_1_AMPERE;
+    } else if (compute_capability >= MIN_CC_DP4A) {
+        mmq_x  =  MMQ_X_Q4_1_PASCAL;
+        mmq_y  =  MMQ_Y_Q4_1_PASCAL;
+        nwarps = NWARPS_Q4_1_PASCAL;
     } else {
-        const int mmq_x  = 64;
-        const int mmq_y  = 64;
-        const int nwarps = 8;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
-                load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
-                load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        GGML_ASSERT(false);
+    }
 
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
 
@@ -4049,48 +4375,32 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
     CUDA_CHECK(cudaGetDevice(&id));
     const int compute_capability = g_compute_capabilities[id];
 
+    int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_TURING) {
-        const int mmq_x  = 128;
-        const int mmq_y  = 64;
-        const int nwarps = 4;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
-                load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
-                load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        mmq_x  =  MMQ_X_Q5_0_AMPERE;
+        mmq_y  =  MMQ_Y_Q5_0_AMPERE;
+        nwarps = NWARPS_Q5_0_AMPERE;
+    } else if (compute_capability >= MIN_CC_DP4A) {
+        mmq_x  =  MMQ_X_Q5_0_PASCAL;
+        mmq_y  =  MMQ_Y_Q5_0_PASCAL;
+        nwarps = NWARPS_Q5_0_PASCAL;
     } else {
-        const int mmq_x  = 64;
-        const int mmq_y  = 64;
-        const int nwarps = 8;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
-                load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
-                load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
 
@@ -4102,48 +4412,32 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
     CUDA_CHECK(cudaGetDevice(&id));
     const int compute_capability = g_compute_capabilities[id];
 
+    int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_TURING) {
-        const int mmq_x  = 128;
-        const int mmq_y  = 64;
-        const int nwarps = 8;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
-                load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
-                load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        mmq_x  =  MMQ_X_Q5_1_AMPERE;
+        mmq_y  =  MMQ_Y_Q5_1_AMPERE;
+        nwarps = NWARPS_Q5_1_AMPERE;
+    } else if (compute_capability >= MIN_CC_DP4A) {
+        mmq_x  =  MMQ_X_Q5_1_PASCAL;
+        mmq_y  =  MMQ_Y_Q5_1_PASCAL;
+        nwarps = NWARPS_Q5_1_PASCAL;
     } else {
-        const int mmq_x  = 64;
-        const int mmq_y  = 64;
-        const int nwarps = 8;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
-                load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
-                load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
 
@@ -4155,48 +4449,32 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
     CUDA_CHECK(cudaGetDevice(&id));
     const int compute_capability = g_compute_capabilities[id];
 
+    int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_TURING) {
-        const int mmq_x  = 128;
-        const int mmq_y  = 64;
-        const int nwarps = 4;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
-                load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
-                load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        mmq_x  =  MMQ_X_Q8_0_AMPERE;
+        mmq_y  =  MMQ_Y_Q8_0_AMPERE;
+        nwarps = NWARPS_Q8_0_AMPERE;
+    } else if (compute_capability >= MIN_CC_DP4A) {
+        mmq_x  =  MMQ_X_Q8_0_PASCAL;
+        mmq_y  =  MMQ_Y_Q8_0_PASCAL;
+        nwarps = NWARPS_Q8_0_PASCAL;
     } else {
-        const int mmq_x  = 64;
-        const int mmq_y  = 64;
-        const int nwarps = 8;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
-                load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
-                load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
 
@@ -4208,48 +4486,32 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
     CUDA_CHECK(cudaGetDevice(&id));
     const int compute_capability = g_compute_capabilities[id];
 
+    int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_TURING) {
-        const int mmq_x  = 64;
-        const int mmq_y  = 128;
-        const int nwarps = 4;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
-                load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
-                load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        mmq_x  =  MMQ_X_Q2_K_AMPERE;
+        mmq_y  =  MMQ_Y_Q2_K_AMPERE;
+        nwarps = NWARPS_Q2_K_AMPERE;
+    } else if (compute_capability >= MIN_CC_DP4A) {
+        mmq_x  =  MMQ_X_Q2_K_PASCAL;
+        mmq_y  =  MMQ_Y_Q2_K_PASCAL;
+        nwarps = NWARPS_Q2_K_PASCAL;
     } else {
-        const int mmq_x  = 64;
-        const int mmq_y  = 64;
-        const int nwarps = 8;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
-                load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
-                load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
 
@@ -4261,48 +4523,32 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
     CUDA_CHECK(cudaGetDevice(&id));
     const int compute_capability = g_compute_capabilities[id];
 
+    int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_TURING) {
-        const int mmq_x  = 128;
-        const int mmq_y  = 128;
-        const int nwarps = 4;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
-                load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
-                load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        mmq_x  =  MMQ_X_Q3_K_AMPERE;
+        mmq_y  =  MMQ_Y_Q3_K_AMPERE;
+        nwarps = NWARPS_Q3_K_AMPERE;
+    } else if (compute_capability >= MIN_CC_DP4A) {
+        mmq_x  =  MMQ_X_Q3_K_PASCAL;
+        mmq_y  =  MMQ_Y_Q3_K_PASCAL;
+        nwarps = NWARPS_Q3_K_PASCAL;
     } else {
-        const int mmq_x  = 64;
-        const int mmq_y  = 64;
-        const int nwarps = 8;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
-                load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
-                load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
 
@@ -4314,48 +4560,32 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
     CUDA_CHECK(cudaGetDevice(&id));
     const int compute_capability = g_compute_capabilities[id];
 
+    int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_TURING) {
-        const int mmq_x  = 64;
-        const int mmq_y  = 128;
-        const int nwarps = 4;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
-                load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
-                load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        mmq_x  =  MMQ_X_Q4_K_AMPERE;
+        mmq_y  =  MMQ_Y_Q4_K_AMPERE;
+        nwarps = NWARPS_Q4_K_AMPERE;
+    } else if (compute_capability >= MIN_CC_DP4A) {
+        mmq_x  =  MMQ_X_Q4_K_PASCAL;
+        mmq_y  =  MMQ_Y_Q4_K_PASCAL;
+        nwarps = NWARPS_Q4_K_PASCAL;
     } else {
-        const int mmq_x  = 32;
-        const int mmq_y  = 64;
-        const int nwarps = 8;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
-                load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
-                load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
 
@@ -4367,48 +4597,32 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
     CUDA_CHECK(cudaGetDevice(&id));
     const int compute_capability = g_compute_capabilities[id];
 
+    int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_TURING) {
-        const int mmq_x  = 64;
-        const int mmq_y  = 128;
-        const int nwarps = 4;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
-                load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
-                load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        mmq_x  =  MMQ_X_Q5_K_AMPERE;
+        mmq_y  =  MMQ_Y_Q5_K_AMPERE;
+        nwarps = NWARPS_Q5_K_AMPERE;
+    } else if (compute_capability >= MIN_CC_DP4A) {
+        mmq_x  =  MMQ_X_Q5_K_PASCAL;
+        mmq_y  =  MMQ_Y_Q5_K_PASCAL;
+        nwarps = NWARPS_Q5_K_PASCAL;
     } else {
-        const int mmq_x  = 64;
-        const int mmq_y  = 64;
-        const int nwarps = 8;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
-                load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
-                load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
 
@@ -4420,48 +4634,32 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
     CUDA_CHECK(cudaGetDevice(&id));
     const int compute_capability = g_compute_capabilities[id];
 
+    int mmq_x, mmq_y, nwarps;
     if (compute_capability >= CC_TURING) {
-        const int mmq_x  = 64;
-        const int mmq_y  = 64;
-        const int nwarps = 4;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
-                load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
-                load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        mmq_x  =  MMQ_X_Q6_K_AMPERE;
+        mmq_y  =  MMQ_Y_Q6_K_AMPERE;
+        nwarps = NWARPS_Q6_K_AMPERE;
+    } else if (compute_capability >= MIN_CC_DP4A) {
+        mmq_x  =  MMQ_X_Q6_K_PASCAL;
+        mmq_y  =  MMQ_Y_Q6_K_PASCAL;
+        nwarps = NWARPS_Q6_K_PASCAL;
     } else {
-        const int mmq_x  = 32;
-        const int mmq_y  = 64;
-        const int nwarps = 8;
-
-        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-        const dim3 block_nums(block_num_x, block_num_y, 1);
-        const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-        if (nrows_x % mmq_y == 0) {
-            const bool need_check = false;
-            mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
-                load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        } else {
-            const bool need_check = true;
-            mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
-                load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-        }
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }