]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: deduplicate mmq code (llama/7397)
authorJohannes Gäßler <redacted>
Tue, 21 May 2024 14:02:12 +0000 (16:02 +0200)
committerGeorgi Gerganov <redacted>
Sun, 16 Jun 2024 15:19:48 +0000 (18:19 +0300)
ggml-cuda/mmq.cu

index 7948f1b1237fa81e04a1ecaaaebd5b1c2c31b47e..5b540d375031b7f969729561c8a612e286f217e5 100644 (file)
@@ -9,6 +9,135 @@ typedef float (*vec_dot_q_mul_mat_cuda_t)(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k);
 typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
+typedef void (mul_mat_q_t)(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst);
+
+struct mmq_arch_config_t {
+    int x;
+    int y;
+    int nwarps;
+};
+
+struct mmq_config_t {
+    mmq_arch_config_t rdna2;
+    mmq_arch_config_t rdna1;
+    mmq_arch_config_t ampere;
+    mmq_arch_config_t pascal;
+};
+
+constexpr mmq_config_t MMQ_CONFIG_Q4_0 = {
+//        x    y  nwarps
+        { 64, 128, 8},
+        { 64,  64, 8},
+#ifdef CUDA_USE_TENSOR_CORES
+        {  4,  32, 4},
+#else
+        { 64, 128, 4},
+#endif // CUDA_USE_TENSOR_CORES
+        { 64,  64, 8},
+};
+constexpr mmq_config_t MMQ_CONFIG_Q4_1 = {
+//        x    y  nwarps
+        { 64, 128, 8},
+        { 64,  64, 8},
+#ifdef CUDA_USE_TENSOR_CORES
+        {  4,  32, 4},
+#else
+        { 64, 128, 4},
+#endif // CUDA_USE_TENSOR_CORES
+        { 64,  64, 8},
+};
+constexpr mmq_config_t MMQ_CONFIG_Q5_0 = {
+//        x    y  nwarps
+        { 64, 128, 8},
+        { 64,  64, 8},
+#ifdef CUDA_USE_TENSOR_CORES
+        {  4,  32, 4},
+#else
+        {128,  64, 4},
+#endif // CUDA_USE_TENSOR_CORES
+        { 64,  64, 8},
+};
+constexpr mmq_config_t MMQ_CONFIG_Q5_1 = {
+//        x    y  nwarps
+        { 64, 128, 8},
+        { 64,  64, 8},
+#ifdef CUDA_USE_TENSOR_CORES
+        {  4,  32, 4},
+#else
+        {128,  64, 4},
+#endif // CUDA_USE_TENSOR_CORES
+        { 64,  64, 8},
+};
+constexpr mmq_config_t MMQ_CONFIG_Q8_0 = {
+//        x    y  nwarps
+        { 64, 128, 8},
+        { 64,  64, 8},
+#ifdef CUDA_USE_TENSOR_CORES
+        {  4,  32, 4},
+#else
+        {128,  64, 4},
+#endif // CUDA_USE_TENSOR_CORES
+        { 64,  64, 8},
+};
+constexpr mmq_config_t MMQ_CONFIG_Q2_K = {
+//        x    y  nwarps
+        { 64, 128, 8},
+        {128,  32, 8},
+#ifdef CUDA_USE_TENSOR_CORES
+        {  4,  32, 4},
+#else
+        { 64, 128, 4},
+#endif // CUDA_USE_TENSOR_CORES
+        { 64,  64, 8},
+};
+constexpr mmq_config_t MMQ_CONFIG_Q3_K = {
+//        x    y  nwarps
+        {128,  64, 8},
+        { 32, 128, 8},
+#ifdef CUDA_USE_TENSOR_CORES
+        {  4,  32, 4},
+#else
+        {128, 128, 4},
+#endif // CUDA_USE_TENSOR_CORES
+        { 64,  64, 8},
+};
+constexpr mmq_config_t MMQ_CONFIG_Q4_K = {
+//        x    y  nwarps
+        { 64, 128, 8},
+        { 32,  64, 8},
+#ifdef CUDA_USE_TENSOR_CORES
+        {  4,  32, 4},
+#else
+        { 64, 128, 4},
+#endif // CUDA_USE_TENSOR_CORES
+        { 64,  64, 8},
+};
+constexpr mmq_config_t MMQ_CONFIG_Q5_K = {
+//        x    y  nwarps
+        { 64, 128, 8},
+        { 32,  64, 8},
+#ifdef CUDA_USE_TENSOR_CORES
+        {  4,  32, 4},
+#else
+        { 64, 128, 4},
+#endif // CUDA_USE_TENSOR_CORES
+        { 64,  64, 8},
+};
+constexpr mmq_config_t MMQ_CONFIG_Q6_K = {
+//        x    y  nwarps
+        { 64, 128, 8},
+        { 32,  64, 8},
+#ifdef CUDA_USE_TENSOR_CORES
+        {  4,  32, 4},
+#else
+        { 64,  64, 4},
+#endif // CUDA_USE_TENSOR_CORES
+        { 64,  64, 8},
+};
+
+// ------------------------------------------------------------
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
     GGML_UNUSED(x_qh);
@@ -943,25 +1072,6 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
 }
 
-#define  MMQ_X_Q4_0_RDNA2  64
-#define  MMQ_Y_Q4_0_RDNA2  128
-#define NWARPS_Q4_0_RDNA2  8
-#define  MMQ_X_Q4_0_RDNA1  64
-#define  MMQ_Y_Q4_0_RDNA1  64
-#define NWARPS_Q4_0_RDNA1  8
-#if defined(CUDA_USE_TENSOR_CORES)
-#define  MMQ_X_Q4_0_AMPERE 4
-#define  MMQ_Y_Q4_0_AMPERE 32
-#define NWARPS_Q4_0_AMPERE 4
-#else
-#define  MMQ_X_Q4_0_AMPERE 64
-#define  MMQ_Y_Q4_0_AMPERE 128
-#define NWARPS_Q4_0_AMPERE 4
-#endif
-#define  MMQ_X_Q4_0_PASCAL 64
-#define  MMQ_Y_Q4_0_PASCAL 64
-#define NWARPS_Q4_0_PASCAL 8
-
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
               allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
@@ -1072,1107 +1182,265 @@ static __device__ __forceinline__ void mul_mat_q(
     }
 }
 
-template <bool need_check> static __global__ void
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*NWARPS_Q4_0_RDNA2, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    mul_mat_q4_0(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+static constexpr __device__ mmq_arch_config_t get_arch_config_device(mmq_config_t mmq_config) {
 
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
 #if defined(RDNA3) || defined(RDNA2)
-    const int mmq_x  =  MMQ_X_Q4_0_RDNA2;
-    const int mmq_y  =  MMQ_Y_Q4_0_RDNA2;
-    const int nwarps = NWARPS_Q4_0_RDNA2;
+    return mmq_config.rdna2;
 #else
-    const int mmq_x  =  MMQ_X_Q4_0_RDNA1;
-    const int mmq_y  =  MMQ_Y_Q4_0_RDNA1;
-    const int nwarps = NWARPS_Q4_0_RDNA1;
+    return mmq_config.rdna1;
 #endif // defined(RDNA3) || defined(RDNA2)
 
-    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
-        load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+#else
 
-#elif __CUDA_ARCH__ >= CC_VOLTA
-    const int mmq_x  =  MMQ_X_Q4_0_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q4_0_AMPERE;
-    const int nwarps = NWARPS_Q4_0_AMPERE;
+#if __CUDA_ARCH__ >= CC_VOLTA
+    return mmq_config.ampere;
+#else
+    return mmq_config.pascal;
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 
-    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
-        load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+}
 
-#elif __CUDA_ARCH__ >= MIN_CC_DP4A
-    const int mmq_x  =  MMQ_X_Q4_0_PASCAL;
-    const int mmq_y  =  MMQ_Y_Q4_0_PASCAL;
-    const int nwarps = NWARPS_Q4_0_PASCAL;
+template <bool need_check> static __global__ void
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_0.rdna2.nwarps, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    mul_mat_q4_0(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q4_0);
 
-    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
-        load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q4_0<arch_config.y>,
+        load_tiles_q4_0<arch_config.y, arch_config.nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     GGML_UNUSED(vec_dot_q4_0_q8_1_mul_mat);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-#define  MMQ_X_Q4_1_RDNA2  64
-#define  MMQ_Y_Q4_1_RDNA2  128
-#define NWARPS_Q4_1_RDNA2  8
-#define  MMQ_X_Q4_1_RDNA1  64
-#define  MMQ_Y_Q4_1_RDNA1  64
-#define NWARPS_Q4_1_RDNA1  8
-#if defined(CUDA_USE_TENSOR_CORES)
-#define  MMQ_X_Q4_1_AMPERE 4
-#define  MMQ_Y_Q4_1_AMPERE 32
-#define NWARPS_Q4_1_AMPERE 4
-#else
-#define  MMQ_X_Q4_1_AMPERE 64
-#define  MMQ_Y_Q4_1_AMPERE 128
-#define NWARPS_Q4_1_AMPERE 4
-#endif
-#define  MMQ_X_Q4_1_PASCAL 64
-#define  MMQ_Y_Q4_1_PASCAL 64
-#define NWARPS_Q4_1_PASCAL 8
-
 template <bool need_check> static __global__ void
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2)
+    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_1.rdna2.nwarps, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
 #elif __CUDA_ARCH__ < CC_VOLTA
-    __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2)
+    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_1.pascal.nwarps, 2)
 #endif // __CUDA_ARCH__ < CC_VOLTA
     mul_mat_q4_1(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    const int mmq_x  =  MMQ_X_Q4_1_RDNA2;
-    const int mmq_y  =  MMQ_Y_Q4_1_RDNA2;
-    const int nwarps = NWARPS_Q4_1_RDNA2;
-#else
-    const int mmq_x  =  MMQ_X_Q4_1_RDNA1;
-    const int mmq_y  =  MMQ_Y_Q4_1_RDNA1;
-    const int nwarps = NWARPS_Q4_1_RDNA1;
-#endif // defined(RDNA3) || defined(RDNA2)
-
-    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
-        load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= CC_VOLTA
-    const int mmq_x  =  MMQ_X_Q4_1_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q4_1_AMPERE;
-    const int nwarps = NWARPS_Q4_1_AMPERE;
-
-    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
-        load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= MIN_CC_DP4A
-    const int mmq_x  =  MMQ_X_Q4_1_PASCAL;
-    const int mmq_y  =  MMQ_Y_Q4_1_PASCAL;
-    const int nwarps = NWARPS_Q4_1_PASCAL;
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q4_1);
 
-    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
-        load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q4_1<arch_config.y>,
+        load_tiles_q4_1<arch_config.y, arch_config.nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     GGML_UNUSED(vec_dot_q4_1_q8_1_mul_mat);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-#define  MMQ_X_Q5_0_RDNA2  64
-#define  MMQ_Y_Q5_0_RDNA2  128
-#define NWARPS_Q5_0_RDNA2  8
-#define  MMQ_X_Q5_0_RDNA1  64
-#define  MMQ_Y_Q5_0_RDNA1  64
-#define NWARPS_Q5_0_RDNA1  8
-#if defined(CUDA_USE_TENSOR_CORES)
-#define  MMQ_X_Q5_0_AMPERE 4
-#define  MMQ_Y_Q5_0_AMPERE 32
-#define NWARPS_Q5_0_AMPERE 4
-#else
-#define  MMQ_X_Q5_0_AMPERE 128
-#define  MMQ_Y_Q5_0_AMPERE 64
-#define NWARPS_Q5_0_AMPERE 4
-#endif
-#define  MMQ_X_Q5_0_PASCAL 64
-#define  MMQ_Y_Q5_0_PASCAL 64
-#define NWARPS_Q5_0_PASCAL 8
-
 template <bool need_check> static __global__ void
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*NWARPS_Q5_0_RDNA2, 2)
+    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q5_0.rdna2.nwarps, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
     mul_mat_q5_0(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    const int mmq_x  =  MMQ_X_Q5_0_RDNA2;
-    const int mmq_y  =  MMQ_Y_Q5_0_RDNA2;
-    const int nwarps = NWARPS_Q5_0_RDNA2;
-#else
-    const int mmq_x  =  MMQ_X_Q5_0_RDNA1;
-    const int mmq_y  =  MMQ_Y_Q5_0_RDNA1;
-    const int nwarps = NWARPS_Q5_0_RDNA1;
-#endif // defined(RDNA3) || defined(RDNA2)
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q5_0);
 
-    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
-        load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= CC_VOLTA
-    const int mmq_x  =  MMQ_X_Q5_0_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q5_0_AMPERE;
-    const int nwarps = NWARPS_Q5_0_AMPERE;
-
-    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
-        load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= MIN_CC_DP4A
-    const int mmq_x  =  MMQ_X_Q5_0_PASCAL;
-    const int mmq_y  =  MMQ_Y_Q5_0_PASCAL;
-    const int nwarps = NWARPS_Q5_0_PASCAL;
-
-    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
-        load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q5_0<arch_config.y>,
+        load_tiles_q5_0<arch_config.y, arch_config.nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     GGML_UNUSED(vec_dot_q5_0_q8_1_mul_mat);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-#define  MMQ_X_Q5_1_RDNA2  64
-#define  MMQ_Y_Q5_1_RDNA2  128
-#define NWARPS_Q5_1_RDNA2  8
-#define  MMQ_X_Q5_1_RDNA1  64
-#define  MMQ_Y_Q5_1_RDNA1  64
-#define NWARPS_Q5_1_RDNA1  8
-#if defined(CUDA_USE_TENSOR_CORES)
-#define  MMQ_X_Q5_1_AMPERE 4
-#define  MMQ_Y_Q5_1_AMPERE 32
-#define NWARPS_Q5_1_AMPERE 4
-#else
-#define  MMQ_X_Q5_1_AMPERE 128
-#define  MMQ_Y_Q5_1_AMPERE 64
-#define NWARPS_Q5_1_AMPERE 4
-#endif
-#define  MMQ_X_Q5_1_PASCAL 64
-#define  MMQ_Y_Q5_1_PASCAL 64
-#define NWARPS_Q5_1_PASCAL 8
-
 template <bool need_check> static __global__ void
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*NWARPS_Q5_1_RDNA2, 2)
+    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q5_1.rdna2.nwarps, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 mul_mat_q5_1(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    const int mmq_x  =  MMQ_X_Q5_1_RDNA2;
-    const int mmq_y  =  MMQ_Y_Q5_1_RDNA2;
-    const int nwarps = NWARPS_Q5_1_RDNA2;
-#else
-    const int mmq_x  =  MMQ_X_Q5_1_RDNA1;
-    const int mmq_y  =  MMQ_Y_Q5_1_RDNA1;
-    const int nwarps = NWARPS_Q5_1_RDNA1;
-#endif // defined(RDNA3) || defined(RDNA2)
-
-    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
-        load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= CC_VOLTA
-    const int mmq_x  =  MMQ_X_Q5_1_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q5_1_AMPERE;
-    const int nwarps = NWARPS_Q5_1_AMPERE;
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q5_1);
 
-    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
-        load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= MIN_CC_DP4A
-    const int mmq_x  =  MMQ_X_Q5_1_PASCAL;
-    const int mmq_y  =  MMQ_Y_Q5_1_PASCAL;
-    const int nwarps = NWARPS_Q5_1_PASCAL;
-
-    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
-        load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q5_1<arch_config.y>,
+        load_tiles_q5_1<arch_config.y, arch_config.nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     GGML_UNUSED(vec_dot_q5_1_q8_1_mul_mat);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-#define  MMQ_X_Q8_0_RDNA2  64
-#define  MMQ_Y_Q8_0_RDNA2  128
-#define NWARPS_Q8_0_RDNA2  8
-#define  MMQ_X_Q8_0_RDNA1  64
-#define  MMQ_Y_Q8_0_RDNA1  64
-#define NWARPS_Q8_0_RDNA1  8
-#if defined(CUDA_USE_TENSOR_CORES)
-#define  MMQ_X_Q8_0_AMPERE 4
-#define  MMQ_Y_Q8_0_AMPERE 32
-#define NWARPS_Q8_0_AMPERE 4
-#else
-#define  MMQ_X_Q8_0_AMPERE 128
-#define  MMQ_Y_Q8_0_AMPERE 64
-#define NWARPS_Q8_0_AMPERE 4
-#endif
-#define  MMQ_X_Q8_0_PASCAL 64
-#define  MMQ_Y_Q8_0_PASCAL 64
-#define NWARPS_Q8_0_PASCAL 8
-
 template <bool need_check> static __global__ void
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*NWARPS_Q8_0_RDNA2, 2)
+    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q8_0.rdna2.nwarps, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
     mul_mat_q8_0(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    const int mmq_x  =  MMQ_X_Q8_0_RDNA2;
-    const int mmq_y  =  MMQ_Y_Q8_0_RDNA2;
-    const int nwarps = NWARPS_Q8_0_RDNA2;
-#else
-    const int mmq_x  =  MMQ_X_Q8_0_RDNA1;
-    const int mmq_y  =  MMQ_Y_Q8_0_RDNA1;
-    const int nwarps = NWARPS_Q8_0_RDNA1;
-#endif // defined(RDNA3) || defined(RDNA2)
-
-    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
-        load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= CC_VOLTA
-    const int mmq_x  =  MMQ_X_Q8_0_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q8_0_AMPERE;
-    const int nwarps = NWARPS_Q8_0_AMPERE;
-
-    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
-        load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= MIN_CC_DP4A
-    const int mmq_x  =  MMQ_X_Q8_0_PASCAL;
-    const int mmq_y  =  MMQ_Y_Q8_0_PASCAL;
-    const int nwarps = NWARPS_Q8_0_PASCAL;
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q8_0);
 
-    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
-        load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q8_0<arch_config.y>,
+        load_tiles_q8_0<arch_config.y, arch_config.nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     GGML_UNUSED(vec_dot_q8_0_q8_1_mul_mat);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-#define  MMQ_X_Q2_K_RDNA2  64
-#define  MMQ_Y_Q2_K_RDNA2  128
-#define NWARPS_Q2_K_RDNA2  8
-#define  MMQ_X_Q2_K_RDNA1  128
-#define  MMQ_Y_Q2_K_RDNA1  32
-#define NWARPS_Q2_K_RDNA1  8
-#if defined(CUDA_USE_TENSOR_CORES)
-#define  MMQ_X_Q2_K_AMPERE 4
-#define  MMQ_Y_Q2_K_AMPERE 32
-#define NWARPS_Q2_K_AMPERE 4
-#else
-#define  MMQ_X_Q2_K_AMPERE 64
-#define  MMQ_Y_Q2_K_AMPERE 128
-#define NWARPS_Q2_K_AMPERE 4
-#endif
-#define  MMQ_X_Q2_K_PASCAL 64
-#define  MMQ_Y_Q2_K_PASCAL 64
-#define NWARPS_Q2_K_PASCAL 8
-
 template <bool need_check> static __global__ void
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*NWARPS_Q2_K_RDNA2, 2)
+    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q2_K.rdna2.nwarps, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 mul_mat_q2_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    const int mmq_x  =  MMQ_X_Q2_K_RDNA2;
-    const int mmq_y  =  MMQ_Y_Q2_K_RDNA2;
-    const int nwarps = NWARPS_Q2_K_RDNA2;
-#else
-    const int mmq_x  =  MMQ_X_Q2_K_RDNA1;
-    const int mmq_y  =  MMQ_Y_Q2_K_RDNA1;
-    const int nwarps = NWARPS_Q2_K_RDNA1;
-#endif // defined(RDNA3) || defined(RDNA2)
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q2_K);
 
-    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
-        load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= CC_VOLTA
-    const int mmq_x  =  MMQ_X_Q2_K_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q2_K_AMPERE;
-    const int nwarps = NWARPS_Q2_K_AMPERE;
-
-    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
-        load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= MIN_CC_DP4A
-    const int mmq_x  =  MMQ_X_Q2_K_PASCAL;
-    const int mmq_y  =  MMQ_Y_Q2_K_PASCAL;
-    const int nwarps = NWARPS_Q2_K_PASCAL;
-
-    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
-        load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q2_K<arch_config.y>,
+        load_tiles_q2_K<arch_config.y, arch_config.nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     GGML_UNUSED(vec_dot_q2_K_q8_1_mul_mat);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-#define  MMQ_X_Q3_K_RDNA2  128
-#define  MMQ_Y_Q3_K_RDNA2  64
-#define NWARPS_Q3_K_RDNA2  8
-#define  MMQ_X_Q3_K_RDNA1  32
-#define  MMQ_Y_Q3_K_RDNA1  128
-#define NWARPS_Q3_K_RDNA1  8
-#if defined(CUDA_USE_TENSOR_CORES)
-#define  MMQ_X_Q3_K_AMPERE 4
-#define  MMQ_Y_Q3_K_AMPERE 32
-#define NWARPS_Q3_K_AMPERE 4
-#else
-#define  MMQ_X_Q3_K_AMPERE 128
-#define  MMQ_Y_Q3_K_AMPERE 128
-#define NWARPS_Q3_K_AMPERE 4
-#endif
-#define  MMQ_X_Q3_K_PASCAL 64
-#define  MMQ_Y_Q3_K_PASCAL 64
-#define NWARPS_Q3_K_PASCAL 8
-
 template <bool need_check> static __global__ void
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2)
+    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q3_K.rdna2.nwarps, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
 #elif __CUDA_ARCH__ < CC_VOLTA
-    __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2)
+    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q3_K.pascal.nwarps, 2)
 #endif // __CUDA_ARCH__ < CC_VOLTA
     mul_mat_q3_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    const int mmq_x  =  MMQ_X_Q3_K_RDNA2;
-    const int mmq_y  =  MMQ_Y_Q3_K_RDNA2;
-    const int nwarps = NWARPS_Q3_K_RDNA2;
-#else
-    const int mmq_x  =  MMQ_X_Q3_K_RDNA1;
-    const int mmq_y  =  MMQ_Y_Q3_K_RDNA1;
-    const int nwarps = NWARPS_Q3_K_RDNA1;
-#endif // defined(RDNA3) || defined(RDNA2)
-
-    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
-        load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= CC_VOLTA
-    const int mmq_x  =  MMQ_X_Q3_K_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q3_K_AMPERE;
-    const int nwarps = NWARPS_Q3_K_AMPERE;
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q3_K);
 
-    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
-        load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= MIN_CC_DP4A
-    const int mmq_x  =  MMQ_X_Q3_K_PASCAL;
-    const int mmq_y  =  MMQ_Y_Q3_K_PASCAL;
-    const int nwarps = NWARPS_Q3_K_PASCAL;
-
-    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
-        load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q3_K<arch_config.y>,
+        load_tiles_q3_K<arch_config.y, arch_config.nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     GGML_UNUSED(vec_dot_q3_K_q8_1_mul_mat);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-#define  MMQ_X_Q4_K_RDNA2  64
-#define  MMQ_Y_Q4_K_RDNA2  128
-#define NWARPS_Q4_K_RDNA2  8
-#define  MMQ_X_Q4_K_RDNA1  32
-#define  MMQ_Y_Q4_K_RDNA1  64
-#define NWARPS_Q4_K_RDNA1  8
-#if defined(CUDA_USE_TENSOR_CORES)
-#define  MMQ_X_Q4_K_AMPERE 4
-#define  MMQ_Y_Q4_K_AMPERE 32
-#define NWARPS_Q4_K_AMPERE 4
-#else
-#define  MMQ_X_Q4_K_AMPERE 64
-#define  MMQ_Y_Q4_K_AMPERE 128
-#define NWARPS_Q4_K_AMPERE 4
-#endif
-#define  MMQ_X_Q4_K_PASCAL 64
-#define  MMQ_Y_Q4_K_PASCAL 64
-#define NWARPS_Q4_K_PASCAL 8
-
 template <bool need_check> static __global__ void
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2)
+    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_K.rdna2.nwarps, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
 #elif __CUDA_ARCH__ < CC_VOLTA
-    __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2)
+    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_K.pascal.nwarps, 2)
 #endif // __CUDA_ARCH__ < CC_VOLTA
     mul_mat_q4_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    const int mmq_x  =  MMQ_X_Q4_K_RDNA2;
-    const int mmq_y  =  MMQ_Y_Q4_K_RDNA2;
-    const int nwarps = NWARPS_Q4_K_RDNA2;
-#else
-    const int mmq_x  =  MMQ_X_Q4_K_RDNA1;
-    const int mmq_y  =  MMQ_Y_Q4_K_RDNA1;
-    const int nwarps = NWARPS_Q4_K_RDNA1;
-#endif // defined(RDNA3) || defined(RDNA2)
-
-    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
-        load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= CC_VOLTA
-    const int mmq_x  =  MMQ_X_Q4_K_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q4_K_AMPERE;
-    const int nwarps = NWARPS_Q4_K_AMPERE;
-
-    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
-        load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= MIN_CC_DP4A
-    const int mmq_x  =  MMQ_X_Q4_K_PASCAL;
-    const int mmq_y  =  MMQ_Y_Q4_K_PASCAL;
-    const int nwarps = NWARPS_Q4_K_PASCAL;
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q4_K);
 
-    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
-        load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q4_K<arch_config.y>,
+        load_tiles_q4_K<arch_config.y, arch_config.nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     GGML_UNUSED(vec_dot_q4_K_q8_1_mul_mat);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-#define  MMQ_X_Q5_K_RDNA2  64
-#define  MMQ_Y_Q5_K_RDNA2  128
-#define NWARPS_Q5_K_RDNA2  8
-#define  MMQ_X_Q5_K_RDNA1  32
-#define  MMQ_Y_Q5_K_RDNA1  64
-#define NWARPS_Q5_K_RDNA1  8
-#if defined(CUDA_USE_TENSOR_CORES)
-#define  MMQ_X_Q5_K_AMPERE 4
-#define  MMQ_Y_Q5_K_AMPERE 32
-#define NWARPS_Q5_K_AMPERE 4
-#else
-#define  MMQ_X_Q5_K_AMPERE 64
-#define  MMQ_Y_Q5_K_AMPERE 128
-#define NWARPS_Q5_K_AMPERE 4
-#endif
-#define  MMQ_X_Q5_K_PASCAL 64
-#define  MMQ_Y_Q5_K_PASCAL 64
-#define NWARPS_Q5_K_PASCAL 8
-
 template <bool need_check> static __global__ void
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*NWARPS_Q5_K_RDNA2, 2)
+    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q5_K.rdna2.nwarps, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 mul_mat_q5_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    const int mmq_x  =  MMQ_X_Q5_K_RDNA2;
-    const int mmq_y  =  MMQ_Y_Q5_K_RDNA2;
-    const int nwarps = NWARPS_Q5_K_RDNA2;
-#else
-    const int mmq_x  =  MMQ_X_Q5_K_RDNA1;
-    const int mmq_y  =  MMQ_Y_Q5_K_RDNA1;
-    const int nwarps = NWARPS_Q5_K_RDNA1;
-#endif // defined(RDNA3) || defined(RDNA2)
-
-    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
-        load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= CC_VOLTA
-    const int mmq_x  =  MMQ_X_Q5_K_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q5_K_AMPERE;
-    const int nwarps = NWARPS_Q5_K_AMPERE;
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q5_K);
 
-    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
-        load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= MIN_CC_DP4A
-    const int mmq_x  =  MMQ_X_Q5_K_PASCAL;
-    const int mmq_y  =  MMQ_Y_Q5_K_PASCAL;
-    const int nwarps = NWARPS_Q5_K_PASCAL;
-
-    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
-        load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q5_K<arch_config.y>,
+        load_tiles_q5_K<arch_config.y, arch_config.nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     GGML_UNUSED(vec_dot_q5_K_q8_1_mul_mat);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-#define  MMQ_X_Q6_K_RDNA2  64
-#define  MMQ_Y_Q6_K_RDNA2  128
-#define NWARPS_Q6_K_RDNA2  8
-#define  MMQ_X_Q6_K_RDNA1  32
-#define  MMQ_Y_Q6_K_RDNA1  64
-#define NWARPS_Q6_K_RDNA1  8
-#if defined(CUDA_USE_TENSOR_CORES)
-#define  MMQ_X_Q6_K_AMPERE 4
-#define  MMQ_Y_Q6_K_AMPERE 32
-#define NWARPS_Q6_K_AMPERE 4
-#else
-#define  MMQ_X_Q6_K_AMPERE 64
-#define  MMQ_Y_Q6_K_AMPERE 64
-#define NWARPS_Q6_K_AMPERE 4
-#endif
-#define  MMQ_X_Q6_K_PASCAL 64
-#define  MMQ_Y_Q6_K_PASCAL 64
-#define NWARPS_Q6_K_PASCAL 8
-
 template <bool need_check> static __global__ void
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2)
+    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q6_K.rdna2.nwarps, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
 #elif __CUDA_ARCH__ < CC_VOLTA
-    __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2)
+    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_K.pascal.nwarps, 2)
 #endif // __CUDA_ARCH__ < CC_VOLTA
     mul_mat_q6_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    const int mmq_x  =  MMQ_X_Q6_K_RDNA2;
-    const int mmq_y  =  MMQ_Y_Q6_K_RDNA2;
-    const int nwarps = NWARPS_Q6_K_RDNA2;
-#else
-    const int mmq_x  =  MMQ_X_Q6_K_RDNA1;
-    const int mmq_y  =  MMQ_Y_Q6_K_RDNA1;
-    const int nwarps = NWARPS_Q6_K_RDNA1;
-#endif // defined(RDNA3) || defined(RDNA2)
-
-    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
-        load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= CC_VOLTA
-    const int mmq_x  =  MMQ_X_Q6_K_AMPERE;
-    const int mmq_y  =  MMQ_Y_Q6_K_AMPERE;
-    const int nwarps = NWARPS_Q6_K_AMPERE;
-
-    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
-        load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
-        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-
-#elif __CUDA_ARCH__ >= MIN_CC_DP4A
-    const int mmq_x  =  MMQ_X_Q6_K_PASCAL;
-    const int mmq_y  =  MMQ_Y_Q6_K_PASCAL;
-    const int nwarps = NWARPS_Q6_K_PASCAL;
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q6_K);
 
-    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
-        load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q6_K<arch_config.y>,
+        load_tiles_q6_K<arch_config.y, arch_config.nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     GGML_UNUSED(vec_dot_q6_K_q8_1_mul_mat);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-static void ggml_mul_mat_q4_0_q8_1_cuda(
-    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
-    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-
-    int id = ggml_cuda_get_device();
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_RDNA2) {
-        mmq_x  =  MMQ_X_Q4_0_RDNA2;
-        mmq_y  =  MMQ_Y_Q4_0_RDNA2;
-        nwarps = NWARPS_Q4_0_RDNA2;
-    } else if (compute_capability >= CC_OFFSET_AMD) {
-        mmq_x  =  MMQ_X_Q4_0_RDNA1;
-        mmq_y  =  MMQ_Y_Q4_0_RDNA1;
-        nwarps = NWARPS_Q4_0_RDNA1;
-    } else if (compute_capability >= CC_VOLTA) {
-        mmq_x  =  MMQ_X_Q4_0_AMPERE;
-        mmq_y  =  MMQ_Y_Q4_0_AMPERE;
-        nwarps = NWARPS_Q4_0_AMPERE;
-    } else if (compute_capability >= MIN_CC_DP4A) {
-        mmq_x  =  MMQ_X_Q4_0_PASCAL;
-        mmq_y  =  MMQ_Y_Q4_0_PASCAL;
-        nwarps = NWARPS_Q4_0_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    } else {
-        const bool need_check = true;
-        mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    }
-}
-
-static void ggml_mul_mat_q4_1_q8_1_cuda(
-    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
-    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-
-    int id = ggml_cuda_get_device();
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_RDNA2) {
-        mmq_x  =  MMQ_X_Q4_1_RDNA2;
-        mmq_y  =  MMQ_Y_Q4_1_RDNA2;
-        nwarps = NWARPS_Q4_1_RDNA2;
-    } else if (compute_capability >= CC_OFFSET_AMD) {
-        mmq_x  =  MMQ_X_Q4_1_RDNA1;
-        mmq_y  =  MMQ_Y_Q4_1_RDNA1;
-        nwarps = NWARPS_Q4_1_RDNA1;
-    } else if (compute_capability >= CC_VOLTA) {
-        mmq_x  =  MMQ_X_Q4_1_AMPERE;
-        mmq_y  =  MMQ_Y_Q4_1_AMPERE;
-        nwarps = NWARPS_Q4_1_AMPERE;
-    } else if (compute_capability >= MIN_CC_DP4A) {
-        mmq_x  =  MMQ_X_Q4_1_PASCAL;
-        mmq_y  =  MMQ_Y_Q4_1_PASCAL;
-        nwarps = NWARPS_Q4_1_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    } else {
-        const bool need_check = true;
-        mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    }
-}
-
-static void ggml_mul_mat_q5_0_q8_1_cuda(
-    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
-    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-
-    int id = ggml_cuda_get_device();
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_RDNA2) {
-        mmq_x  =  MMQ_X_Q5_0_RDNA2;
-        mmq_y  =  MMQ_Y_Q5_0_RDNA2;
-        nwarps = NWARPS_Q5_0_RDNA2;
-    } else if (compute_capability >= CC_OFFSET_AMD) {
-        mmq_x  =  MMQ_X_Q5_0_RDNA1;
-        mmq_y  =  MMQ_Y_Q5_0_RDNA1;
-        nwarps = NWARPS_Q5_0_RDNA1;
-    } else if (compute_capability >= CC_VOLTA) {
-        mmq_x  =  MMQ_X_Q5_0_AMPERE;
-        mmq_y  =  MMQ_Y_Q5_0_AMPERE;
-        nwarps = NWARPS_Q5_0_AMPERE;
-    } else if (compute_capability >= MIN_CC_DP4A) {
-        mmq_x  =  MMQ_X_Q5_0_PASCAL;
-        mmq_y  =  MMQ_Y_Q5_0_PASCAL;
-        nwarps = NWARPS_Q5_0_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    } else {
-        const bool need_check = true;
-        mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    }
-}
-
-static void ggml_mul_mat_q5_1_q8_1_cuda(
-    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
-    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-
-    int id = ggml_cuda_get_device();
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_RDNA2) {
-        mmq_x  =  MMQ_X_Q5_1_RDNA2;
-        mmq_y  =  MMQ_Y_Q5_1_RDNA2;
-        nwarps = NWARPS_Q5_1_RDNA2;
-    } else if (compute_capability >= CC_OFFSET_AMD) {
-        mmq_x  =  MMQ_X_Q5_1_RDNA1;
-        mmq_y  =  MMQ_Y_Q5_1_RDNA1;
-        nwarps = NWARPS_Q5_1_RDNA1;
-    } else if (compute_capability >= CC_VOLTA) {
-        mmq_x  =  MMQ_X_Q5_1_AMPERE;
-        mmq_y  =  MMQ_Y_Q5_1_AMPERE;
-        nwarps = NWARPS_Q5_1_AMPERE;
-    } else if (compute_capability >= MIN_CC_DP4A) {
-        mmq_x  =  MMQ_X_Q5_1_PASCAL;
-        mmq_y  =  MMQ_Y_Q5_1_PASCAL;
-        nwarps = NWARPS_Q5_1_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    } else {
-        const bool need_check = true;
-        mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    }
-}
-
-static void ggml_mul_mat_q8_0_q8_1_cuda(
-    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
-    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-
-    int id = ggml_cuda_get_device();
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_RDNA2) {
-        mmq_x  =  MMQ_X_Q8_0_RDNA2;
-        mmq_y  =  MMQ_Y_Q8_0_RDNA2;
-        nwarps = NWARPS_Q8_0_RDNA2;
-    } else if (compute_capability >= CC_OFFSET_AMD) {
-        mmq_x  =  MMQ_X_Q8_0_RDNA1;
-        mmq_y  =  MMQ_Y_Q8_0_RDNA1;
-        nwarps = NWARPS_Q8_0_RDNA1;
-    } else if (compute_capability >= CC_VOLTA) {
-        mmq_x  =  MMQ_X_Q8_0_AMPERE;
-        mmq_y  =  MMQ_Y_Q8_0_AMPERE;
-        nwarps = NWARPS_Q8_0_AMPERE;
-    } else if (compute_capability >= MIN_CC_DP4A) {
-        mmq_x  =  MMQ_X_Q8_0_PASCAL;
-        mmq_y  =  MMQ_Y_Q8_0_PASCAL;
-        nwarps = NWARPS_Q8_0_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    } else {
-        const bool need_check = true;
-        mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    }
-}
-
-static void ggml_mul_mat_q2_K_q8_1_cuda(
-    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
-    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-
-    int id = ggml_cuda_get_device();
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_RDNA2) {
-        mmq_x  =  MMQ_X_Q2_K_RDNA2;
-        mmq_y  =  MMQ_Y_Q2_K_RDNA2;
-        nwarps = NWARPS_Q2_K_RDNA2;
-    } else if (compute_capability >= CC_OFFSET_AMD) {
-        mmq_x  =  MMQ_X_Q2_K_RDNA1;
-        mmq_y  =  MMQ_Y_Q2_K_RDNA1;
-        nwarps = NWARPS_Q2_K_RDNA1;
-    } else if (compute_capability >= CC_VOLTA) {
-        mmq_x  =  MMQ_X_Q2_K_AMPERE;
-        mmq_y  =  MMQ_Y_Q2_K_AMPERE;
-        nwarps = NWARPS_Q2_K_AMPERE;
-    } else if (compute_capability >= MIN_CC_DP4A) {
-        mmq_x  =  MMQ_X_Q2_K_PASCAL;
-        mmq_y  =  MMQ_Y_Q2_K_PASCAL;
-        nwarps = NWARPS_Q2_K_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    } else {
-        const bool need_check = true;
-        mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    }
-}
-
-static void ggml_mul_mat_q3_K_q8_1_cuda(
-    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
-    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-
-#if QK_K == 256
-
-    int id = ggml_cuda_get_device();
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_RDNA2) {
-        mmq_x  =  MMQ_X_Q3_K_RDNA2;
-        mmq_y  =  MMQ_Y_Q3_K_RDNA2;
-        nwarps = NWARPS_Q3_K_RDNA2;
-    } else if (compute_capability >= CC_OFFSET_AMD) {
-        mmq_x  =  MMQ_X_Q3_K_RDNA1;
-        mmq_y  =  MMQ_Y_Q3_K_RDNA1;
-        nwarps = NWARPS_Q3_K_RDNA1;
-    } else if (compute_capability >= CC_VOLTA) {
-        mmq_x  =  MMQ_X_Q3_K_AMPERE;
-        mmq_y  =  MMQ_Y_Q3_K_AMPERE;
-        nwarps = NWARPS_Q3_K_AMPERE;
-    } else if (compute_capability >= MIN_CC_DP4A) {
-        mmq_x  =  MMQ_X_Q3_K_PASCAL;
-        mmq_y  =  MMQ_Y_Q3_K_PASCAL;
-        nwarps = NWARPS_Q3_K_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    } else {
-        const bool need_check = true;
-        mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    }
-#endif
-}
-
-static void ggml_mul_mat_q4_K_q8_1_cuda(
-    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
-    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-
-    int id = ggml_cuda_get_device();
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_RDNA2) {
-        mmq_x  =  MMQ_X_Q4_K_RDNA2;
-        mmq_y  =  MMQ_Y_Q4_K_RDNA2;
-        nwarps = NWARPS_Q4_K_RDNA2;
-    } else if (compute_capability >= CC_OFFSET_AMD) {
-        mmq_x  =  MMQ_X_Q4_K_RDNA1;
-        mmq_y  =  MMQ_Y_Q4_K_RDNA1;
-        nwarps = NWARPS_Q4_K_RDNA1;
-    } else if (compute_capability >= CC_VOLTA) {
-        mmq_x  =  MMQ_X_Q4_K_AMPERE;
-        mmq_y  =  MMQ_Y_Q4_K_AMPERE;
-        nwarps = NWARPS_Q4_K_AMPERE;
-    } else if (compute_capability >= MIN_CC_DP4A) {
-        mmq_x  =  MMQ_X_Q4_K_PASCAL;
-        mmq_y  =  MMQ_Y_Q4_K_PASCAL;
-        nwarps = NWARPS_Q4_K_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    } else {
-        const bool need_check = true;
-        mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    }
-}
-
-static void ggml_mul_mat_q5_K_q8_1_cuda(
-    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
-    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-
-    int id = ggml_cuda_get_device();
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_RDNA2) {
-        mmq_x  =  MMQ_X_Q5_K_RDNA2;
-        mmq_y  =  MMQ_Y_Q5_K_RDNA2;
-        nwarps = NWARPS_Q5_K_RDNA2;
-    } else if (compute_capability >= CC_OFFSET_AMD) {
-        mmq_x  =  MMQ_X_Q5_K_RDNA1;
-        mmq_y  =  MMQ_Y_Q5_K_RDNA1;
-        nwarps = NWARPS_Q5_K_RDNA1;
-    } else if (compute_capability >= CC_VOLTA) {
-        mmq_x  =  MMQ_X_Q5_K_AMPERE;
-        mmq_y  =  MMQ_Y_Q5_K_AMPERE;
-        nwarps = NWARPS_Q5_K_AMPERE;
-    } else if (compute_capability >= MIN_CC_DP4A) {
-        mmq_x  =  MMQ_X_Q5_K_PASCAL;
-        mmq_y  =  MMQ_Y_Q5_K_PASCAL;
-        nwarps = NWARPS_Q5_K_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    } else {
-        const bool need_check = true;
-        mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    }
-}
-
-static void ggml_mul_mat_q6_K_q8_1_cuda(
-    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
-    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
-
-    int id = ggml_cuda_get_device();
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
-
-    int mmq_x, mmq_y, nwarps;
-    if (compute_capability >= CC_RDNA2) {
-        mmq_x  =  MMQ_X_Q6_K_RDNA2;
-        mmq_y  =  MMQ_Y_Q6_K_RDNA2;
-        nwarps = NWARPS_Q6_K_RDNA2;
-    } else if (compute_capability >= CC_OFFSET_AMD) {
-        mmq_x  =  MMQ_X_Q6_K_RDNA1;
-        mmq_y  =  MMQ_Y_Q6_K_RDNA1;
-        nwarps = NWARPS_Q6_K_RDNA1;
-    } else if (compute_capability >= CC_VOLTA) {
-        mmq_x  =  MMQ_X_Q6_K_AMPERE;
-        mmq_y  =  MMQ_Y_Q6_K_AMPERE;
-        nwarps = NWARPS_Q6_K_AMPERE;
-    } else if (compute_capability >= MIN_CC_DP4A) {
-        mmq_x  =  MMQ_X_Q6_K_PASCAL;
-        mmq_y  =  MMQ_Y_Q6_K_PASCAL;
-        nwarps = NWARPS_Q6_K_PASCAL;
-    } else {
-        GGML_ASSERT(false);
-    }
-
-    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
-    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
-
-    if (nrows_x % mmq_y == 0) {
-        const bool need_check = false;
-        mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    } else {
-        const bool need_check = true;
-        mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
-            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
-    }
-}
+#define MMQ_SWITCH_CASE(type_suffix)                                                                        \
+    case GGML_TYPE_Q##type_suffix: if (row_diff % arch_config.y == 0) {                                     \
+        const bool need_check = false;                                                                      \
+        mul_mat_q##type_suffix<need_check><<<block_nums, block_dims, 0, stream>>>                           \
+            (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst); \
+    } else {                                                                                                \
+        const bool need_check = true;                                                                       \
+        mul_mat_q##type_suffix<need_check><<<block_nums, block_dims, 0, stream>>>                           \
+            (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst); \
+    } break;                                                                                                \
 
 void ggml_cuda_op_mul_mat_q(
     ggml_backend_cuda_context & ctx,
@@ -2190,47 +1458,84 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t row_diff = row_high - row_low;
 
     int id = ggml_cuda_get_device();
+    const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     // the main device has a larger memory buffer to hold the results from all GPUs
     // nrows_dst == nrows of the matrix that the kernel writes into
     const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
 
+    mmq_config_t mmq_config;
+
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            ggml_mul_mat_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            mmq_config = MMQ_CONFIG_Q4_0;
             break;
         case GGML_TYPE_Q4_1:
-            ggml_mul_mat_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            mmq_config = MMQ_CONFIG_Q4_1;
             break;
         case GGML_TYPE_Q5_0:
-            ggml_mul_mat_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            mmq_config = MMQ_CONFIG_Q5_0;
             break;
         case GGML_TYPE_Q5_1:
-            ggml_mul_mat_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            mmq_config = MMQ_CONFIG_Q5_1;
             break;
         case GGML_TYPE_Q8_0:
-            ggml_mul_mat_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            mmq_config = MMQ_CONFIG_Q8_0;
             break;
         case GGML_TYPE_Q2_K:
-            ggml_mul_mat_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            mmq_config = MMQ_CONFIG_Q2_K;
             break;
         case GGML_TYPE_Q3_K:
-            ggml_mul_mat_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            mmq_config = MMQ_CONFIG_Q3_K;
             break;
         case GGML_TYPE_Q4_K:
-            ggml_mul_mat_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            mmq_config = MMQ_CONFIG_Q4_K;
             break;
         case GGML_TYPE_Q5_K:
-            ggml_mul_mat_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            mmq_config = MMQ_CONFIG_Q5_K;
             break;
         case GGML_TYPE_Q6_K:
-            ggml_mul_mat_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+            mmq_config = MMQ_CONFIG_Q6_K;
             break;
         default:
             GGML_ASSERT(false);
             break;
     }
 
+    mmq_arch_config_t arch_config;
+    if (compute_capability >= CC_RDNA2) {
+        arch_config = mmq_config.rdna2;
+    } else if (compute_capability >= CC_OFFSET_AMD) {
+        arch_config = mmq_config.rdna1;
+    } else if (compute_capability >= CC_VOLTA) {
+        arch_config = mmq_config.ampere;
+    } else if (compute_capability >= MIN_CC_DP4A) {
+        arch_config = mmq_config.pascal;
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    const int block_num_x = (row_diff   + arch_config.y - 1) / arch_config.y;
+    const int block_num_y = (src1_ncols + arch_config.x - 1) / arch_config.x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, arch_config.nwarps, 1);
+
+    switch (src0->type) {
+        MMQ_SWITCH_CASE(4_0)
+        MMQ_SWITCH_CASE(4_1)
+        MMQ_SWITCH_CASE(5_0)
+        MMQ_SWITCH_CASE(5_1)
+        MMQ_SWITCH_CASE(8_0)
+        MMQ_SWITCH_CASE(2_K)
+        MMQ_SWITCH_CASE(3_K)
+        MMQ_SWITCH_CASE(4_K)
+        MMQ_SWITCH_CASE(5_K)
+        MMQ_SWITCH_CASE(6_K)
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+
     GGML_UNUSED(src1);
     GGML_UNUSED(dst);
     GGML_UNUSED(src1_ddf_i);