]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: tuned mul_mat_q kernels (#2546)
authorJohannes Gäßler <redacted>
Wed, 9 Aug 2023 07:42:34 +0000 (09:42 +0200)
committerGitHub <redacted>
Wed, 9 Aug 2023 07:42:34 +0000 (09:42 +0200)
Makefile
README.md
ggml-cuda.cu

index 32598edfe847d6db1e67ab153fd6b8dd7d617446..f01bf0c8324edf8631e6159c2025b6c398942efd 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -253,11 +253,6 @@ ifdef LLAMA_CUDA_KQUANTS_ITER
 else
        NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
 endif
-ifdef LLAMA_CUDA_MMQ_Y
-       NVCCFLAGS += -DGGML_CUDA_MMQ_Y=$(LLAMA_CUDA_MMQ_Y)
-else
-       NVCCFLAGS += -DGGML_CUDA_MMQ_Y=64
-endif # LLAMA_CUDA_MMQ_Y
 #ifdef LLAMA_CUDA_CUBLAS
 #      NVCCFLAGS += -DGGML_CUDA_CUBLAS
 #endif # LLAMA_CUDA_CUBLAS
index 2ece294b7c94709be0b0daefaf0d9f13f28f72ce..6900b1152e7362b81e19f4aedbd18443c236414c 100644 (file)
--- a/README.md
+++ b/README.md
@@ -406,7 +406,6 @@ Building the program with BLAS support may lead to some performance improvements
 --->
   | Option                  | Legal values           | Default | Description |
   |-------------------------|------------------------|---------|-------------|
-  | LLAMA_CUDA_MMQ_Y        | Positive integer >= 32 |      64 | Tile size in y direction when using the custom CUDA kernels for prompt processing. Higher values can be faster depending on the amount of shared memory available. Power of 2 heavily recommended. |
   | LLAMA_CUDA_FORCE_DMMV   | Boolean                |   false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
   | LLAMA_CUDA_DMMV_X       | Positive integer >= 32 |      32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
   | LLAMA_CUDA_MMV_Y        | Positive integer       |       1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
index 9d42efb0d0b03a98b11c16f19260a10c8b9009da..6390b1158b6a6c5b792ff8bd9f1731c7c499596f 100644 (file)
@@ -14,6 +14,7 @@
 #include "ggml.h"
 
 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define CC_TURING   700
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -262,10 +263,6 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 
-#ifndef GGML_CUDA_MMQ_Y
-#define GGML_CUDA_MMQ_Y 64
-#endif // GGML_CUDA_MMQ_Y
-
 // dmmv = dequantize_mul_mat_vec
 #ifndef GGML_CUDA_DMMV_X
 #define GGML_CUDA_DMMV_X 32
@@ -285,6 +282,20 @@ struct ggml_tensor_extra_gpu {
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
 };
 
+static int g_device_count = -1;
+static int g_main_device = 0;
+static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
+static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
+static bool g_mul_mat_q = false;
+
+static void * g_scratch_buffer = nullptr;
+static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
+static size_t g_scratch_offset = 0;
+
+static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
+
+static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -1549,8 +1560,8 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
 #else
     const float2 dm8f = __half22float2(dm8);
     const float2 ds8f = __half22float2(ds8);
-    const float d8d8 = dm8.x * ds8.x;
-    const float m8s8 = dm8.y * ds8.y;
+    const float d8d8 = dm8f.x * ds8f.x;
+    const float m8s8 = dm8f.y * ds8f.y;
 #endif // GGML_CUDA_F16
 
     // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
@@ -1884,21 +1895,21 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
     return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
 }
 
-static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int  tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE)       + GGML_CUDA_MMQ_Y];
-    __shared__ float tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_0) + GGML_CUDA_MMQ_Y/QI4_0];
+    __shared__ int  tile_x_qs[mmq_y * (WARP_SIZE)       + mmq_y];
+    __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0];
 
     *x_ql = tile_x_qs;
     *x_dm = (half2 *) tile_x_d;
 }
 
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  8);
+    __builtin_assume(i_offset <  nwarps);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
 
@@ -1910,7 +1921,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
     float * x_dmf = (float *) x_dm;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
         int i = i0 + i_offset;
 
         if (need_check) {
@@ -1920,39 +1931,30 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
         const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
 
         x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
-        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
+        // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
     }
 
-//     const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
-//     const int kbxd = k % blocks_per_tile_x_row;
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+    const int kbxd = k % blocks_per_tile_x_row;
 
-// #pragma unroll
-//     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_0) {
-//         FIXME out-of-bounds
-//         const int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
+        int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
 
-//         if (i >= GGML_CUDA_MMQ_Y) {
-//             return;
-//         }
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
-//         const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
 
-//         x_dm[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd].x = bxi->d;
-//     }
+        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
+    }
 }
 
 static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
 
-    __builtin_assume(i >= 0);
-    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
-    __builtin_assume(j >= 0);
-    __builtin_assume(j <  WARP_SIZE);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
-    __builtin_assume(k % VDR_Q4_0_Q8_1_MMQ == 0);
-
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const float * x_dmf = (float *) x_dm;
 
@@ -1960,13 +1962,13 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
 
 #pragma unroll
     for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
-        u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
-        u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI4_0];
+        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
     }
 
     return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
         (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
-         y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
 }
 
 static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
@@ -1987,21 +1989,21 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
     return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
 }
 
-static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int   tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE) +     + GGML_CUDA_MMQ_Y];
-    __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_1) + GGML_CUDA_MMQ_Y/QI4_1];
+    __shared__ int   tile_x_qs[mmq_y * (WARP_SIZE) +     + mmq_y];
+    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1];
 
     *x_ql = tile_x_qs;
     *x_dm = tile_x_dm;
 }
 
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  8);
+    __builtin_assume(i_offset <  nwarps);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
 
@@ -2011,7 +2013,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
     const block_q4_1 * bx0 = (block_q4_1 *) vx;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
         int i = i0 + i_offset;
 
         if (need_check) {
@@ -2027,7 +2029,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
     const int kbxd = k % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_1) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
         int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
 
         if (need_check) {
@@ -2044,27 +2046,19 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
 
-    __builtin_assume(i >= 0);
-    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
-    __builtin_assume(j >= 0);
-    __builtin_assume(j <  WARP_SIZE);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
-    __builtin_assume(k % VDR_Q4_1_Q8_1_MMQ == 0);
-
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
 
     int u[2*VDR_Q4_1_Q8_1_MMQ];
 
 #pragma unroll
     for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
-        u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
-        u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI4_1];
+        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
     }
 
     return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
         (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
-         y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
 }
 
 static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
@@ -2087,21 +2081,21 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
     return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);
 }
 
-static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int  tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE)     + GGML_CUDA_MMQ_Y];
-    __shared__ float tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0) + GGML_CUDA_MMQ_Y/QI5_0];
+    __shared__ int  tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
+    __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0];
 
     *x_ql = tile_x_ql;
     *x_dm = (half2 *) tile_x_d;
 }
 
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  8);
+    __builtin_assume(i_offset <  nwarps);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
 
@@ -2111,7 +2105,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
     const block_q5_0 * bx0 = (block_q5_0 *) vx;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
         int i = i0 + i_offset;
 
         if (need_check) {
@@ -2147,7 +2141,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
     float * x_dmf = (float *) x_dm;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_0) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
         int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
 
         if (need_check) {
@@ -2164,14 +2158,6 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
 
-    __builtin_assume(i >= 0);
-    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
-    __builtin_assume(j >= 0);
-    __builtin_assume(j <  WARP_SIZE);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
-    __builtin_assume(k % VDR_Q5_0_Q8_1_MMQ == 0);
-
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
     const float * x_dmf = (const float *) x_dm;
@@ -2181,12 +2167,12 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
 
 #pragma unroll
     for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
-        u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
-        u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI5_0];
+        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
     }
 
     return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
-        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
 }
 
 static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
@@ -2209,21 +2195,21 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
     return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
 }
 
-static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE)     + GGML_CUDA_MMQ_Y];
-    __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1) + GGML_CUDA_MMQ_Y/QI5_1];
+    __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
+    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1];
 
     *x_ql = tile_x_ql;
     *x_dm = tile_x_dm;
 }
 
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  8);
+    __builtin_assume(i_offset < nwarps);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
 
@@ -2233,7 +2219,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
     const block_q5_1 * bx0 = (block_q5_1 *) vx;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
         int i = i0 + i_offset;
 
         if (need_check) {
@@ -2266,7 +2252,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
     const int kbxd = k % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_1) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
         int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
 
         if (need_check) {
@@ -2283,14 +2269,6 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
 
-    __builtin_assume(i >= 0);
-    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
-    __builtin_assume(j >= 0);
-    __builtin_assume(j <  WARP_SIZE);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
-    __builtin_assume(k % VDR_Q5_1_Q8_1_MMQ == 0);
-
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
 
@@ -2298,12 +2276,12 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
 
 #pragma unroll
     for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
-        u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
-        u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI5_1];
+        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
     }
 
     return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
-        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
 }
 
 static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
@@ -2323,21 +2301,21 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
     return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, bq8_1->ds.x);
 }
 
-static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int  tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE)       + GGML_CUDA_MMQ_Y];
-    __shared__ float tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI8_0) + GGML_CUDA_MMQ_Y/QI8_0];
+    __shared__ int  tile_x_qs[mmq_y * (WARP_SIZE)       + mmq_y];
+    __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0];
 
     *x_ql = tile_x_qs;
     *x_dm = (half2 *) tile_x_d;
 }
 
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  8);
+    __builtin_assume(i_offset <  nwarps);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
 
@@ -2348,7 +2326,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q8_
     const block_q8_0 * bx0 = (block_q8_0 *) vx;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
         int i = i0 + i_offset;
 
         if (need_check) {
@@ -2358,41 +2336,29 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q8_
         const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
 
         x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
-        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbx] = bxi->d;
     }
 
-//     const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
-//     const int kbxd = k % blocks_per_tile_x_row;
+    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
+    const int kbxd = k % blocks_per_tile_x_row;
 
-// #pragma unroll
-//     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI8_0) {
-//         FIXME out-of-bounds
-//         const int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
+        int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
 
-// #if GGML_CUDA_MMQ_Y < 64
-//         if (i >= GGML_CUDA_MMQ_Y) {
-//             return;
-//         }
-// #endif // GGML_CUDA_MMQ_Y < 64
+        if (need_check) {
+            i = min(i, i_max);
+        }
 
-//         const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
 
-//         x_dm[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd].x = bxi->d;
-//     }
+        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
+    }
 }
 
 static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
 
-    __builtin_assume(i >= 0);
-    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
-    __builtin_assume(j >= 0);
-    __builtin_assume(j <  WARP_SIZE);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
-    __builtin_assume(k % VDR_Q8_0_Q8_1_MMQ == 0);
-
     const float * x_dmf = (const float *) x_dm;
     const float * y_df  = (const float *) y_ds;
 
@@ -2424,23 +2390,23 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
     return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
 }
 
-static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE)       + GGML_CUDA_MMQ_Y];
-    __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI2_K) + GGML_CUDA_MMQ_Y/QI2_K];
-    __shared__ int   tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/4)     + GGML_CUDA_MMQ_Y/4];
+    __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE)       + mmq_y];
+    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K];
+    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/4)     + mmq_y/4];
 
     *x_ql = tile_x_ql;
     *x_dm = tile_x_dm;
     *x_sc = tile_x_sc;
 }
 
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  8);
+    __builtin_assume(i_offset <  nwarps);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
 
@@ -2450,7 +2416,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q2_
     const block_q2_K * bx0 = (block_q2_K *) vx;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
         int i = i0 + i_offset;
 
         if (need_check) {
@@ -2466,8 +2432,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q2_
     const int kbxd = k % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI2_K) {
-        int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
+        int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2479,7 +2445,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q2_
     }
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 4) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
         int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
 
         if (need_check) {
@@ -2496,14 +2462,6 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
 
-    __builtin_assume(i >= 0);
-    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
-    __builtin_assume(j >= 0);
-    __builtin_assume(j <  WARP_SIZE);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
-    __builtin_assume(k % VDR_Q2_K_Q8_1_MMQ == 0);
-
     const int kbx = k / QI2_K;
     const int ky  = (k % QI2_K) * QR2_K;
     const float * y_df = (const float *) y_ds;
@@ -2520,7 +2478,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
 
     const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
 
-    const int index_y = j * (QR2_K*WARP_SIZE) + QR2_K*k;
+    const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;
     return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
 }
 
@@ -2551,12 +2509,12 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
     return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
 }
 
-static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE)       + GGML_CUDA_MMQ_Y];
-    __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI3_K) + GGML_CUDA_MMQ_Y/QI3_K];
-    __shared__ int   tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/2)     + GGML_CUDA_MMQ_Y/2];
-    __shared__ int   tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/4)     + GGML_CUDA_MMQ_Y/4];
+    __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE)       + mmq_y];
+    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K];
+    __shared__ int   tile_x_qh[mmq_y * (WARP_SIZE/2)     + mmq_y/2];
+    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/4)     + mmq_y/4];
 
     *x_ql = tile_x_ql;
     *x_dm = tile_x_dm;
@@ -2564,12 +2522,12 @@ static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 **
     *x_sc = tile_x_sc;
 }
 
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  8);
+    __builtin_assume(i_offset <  nwarps);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
 
@@ -2579,7 +2537,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
     const block_q3_K * bx0 = (block_q3_K *) vx;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
         int i = i0 + i_offset;
 
         if (need_check) {
@@ -2596,8 +2554,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
     float * x_dmf = (float *) x_dm;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI3_K) {
-        int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
+        int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2609,7 +2567,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
     }
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 2) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
         int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
 
         if (need_check) {
@@ -2623,7 +2581,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
     }
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 4) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
         int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
 
         if (need_check) {
@@ -2652,14 +2610,6 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
 
-    __builtin_assume(i >= 0);
-    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
-    __builtin_assume(j >= 0);
-    __builtin_assume(j <  WARP_SIZE);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
-    __builtin_assume(k % VDR_Q3_K_Q8_1_MMQ == 0);
-
     const int kbx  = k / QI3_K;
     const int ky  = (k % QI3_K) * QR3_K;
     const float * x_dmf = (const float *) x_dm;
@@ -2681,7 +2631,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
         v[l] = __vsubss4(vll, vlh);
     }
 
-    const int index_y = j * (QR3_K*WARP_SIZE) + k*QR3_K;
+    const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
     return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
 }
 
@@ -2778,23 +2728,23 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 #endif
 }
 
-static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE)       + GGML_CUDA_MMQ_Y];
-    __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_K) + GGML_CUDA_MMQ_Y/QI4_K];
-    __shared__ int   tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8)     + GGML_CUDA_MMQ_Y/8];
+    __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE)       + mmq_y];
+    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K];
+    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/8)     + mmq_y/8];
 
     *x_ql = tile_x_ql;
     *x_dm = tile_x_dm;
     *x_sc = tile_x_sc;
 }
 
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  8);
+    __builtin_assume(i_offset <  nwarps);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
 
@@ -2804,7 +2754,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
     const block_q4_K * bx0 = (block_q4_K *) vx;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
         int i = i0 + i_offset;
 
         if (need_check) {
@@ -2820,8 +2770,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
     const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_K) {
-        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
+        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2833,8 +2783,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
     }
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
-        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2858,14 +2808,6 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
 
-    __builtin_assume(i >= 0);
-    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
-    __builtin_assume(j >= 0);
-    __builtin_assume(j <  WARP_SIZE);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
-    __builtin_assume(k % VDR_Q4_K_Q8_1_MMQ == 0);
-
     int v[QR4_K*VDR_Q4_K_Q8_1_MMQ];
 
 #pragma unroll
@@ -2876,7 +2818,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
 
     const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
 
-    const int index_y = j * (QR4_K*WARP_SIZE) + QR4_K*k;
+    const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
     return vec_dot_q4_K_q8_1_impl_mmq(v, &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
 }
 
@@ -2969,23 +2911,23 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 #endif
 }
 
-static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE)     + GGML_CUDA_MMQ_Y];
-    __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_K) + GGML_CUDA_MMQ_Y/QI5_K];
-    __shared__ int   tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8)     + GGML_CUDA_MMQ_Y/8];
+    __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
+    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K];
+    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/8)     + mmq_y/8];
 
     *x_ql = tile_x_ql;
     *x_dm = tile_x_dm;
     *x_sc = tile_x_sc;
 }
 
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  8);
+    __builtin_assume(i_offset <  nwarps);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
 
@@ -2995,7 +2937,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
     const block_q5_K * bx0 = (block_q5_K *) vx;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
         int i = i0 + i_offset;
 
         if (need_check) {
@@ -3024,8 +2966,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
     const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_K) {
-        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
+        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
@@ -3037,8 +2979,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
     }
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
-        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
@@ -3062,18 +3004,10 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
 
-    __builtin_assume(i >= 0);
-    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
-    __builtin_assume(j >= 0);
-    __builtin_assume(j <  WARP_SIZE);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
-    __builtin_assume(k % VDR_Q5_K_Q8_1_MMQ == 0);
-
     const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
 
-    const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
-    const int index_y = j * (QR5_K*WARP_SIZE)     + QR5_K*k;
+    const int index_x = i * (QR5_K*WARP_SIZE + 1) +  QR5_K*k;
+    const int index_y = j * WARP_SIZE             + (QR5_K*k) % WARP_SIZE;
     return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
 }
 
@@ -3103,23 +3037,23 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
     return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
 }
 
-static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE)     + GGML_CUDA_MMQ_Y];
-    __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI6_K) + GGML_CUDA_MMQ_Y/QI6_K];
-    __shared__ int   tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8)     + GGML_CUDA_MMQ_Y/8];
+    __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
+    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K];
+    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/8)     + mmq_y/8];
 
     *x_ql = tile_x_ql;
     *x_dm = tile_x_dm;
     *x_sc = tile_x_sc;
 }
 
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
     __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  8);
+    __builtin_assume(i_offset <  nwarps);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
 
@@ -3129,7 +3063,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q6_
     const block_q6_K * bx0 = (block_q6_K *) vx;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
         int i = i0 + i_offset;
 
         if (need_check) {
@@ -3159,8 +3093,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q6_
     float * x_dmf = (float *) x_dm;
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI6_K) {
-        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
+        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
@@ -3172,8 +3106,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q6_
     }
 
 #pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
-        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
@@ -3189,25 +3123,17 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
 
-    __builtin_assume(i >= 0);
-    __builtin_assume(i <  GGML_CUDA_MMQ_Y);
-    __builtin_assume(j >= 0);
-    __builtin_assume(j <  WARP_SIZE);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
-    __builtin_assume(k % VDR_Q6_K_Q8_1_MMQ == 0);
-
     const float * x_dmf = (const float *) x_dm;
     const float * y_df  = (const float *) y_ds;
 
     const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
 
-    const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
-    const int index_y = j * (QR6_K*WARP_SIZE)     + QR6_K*k;
+    const int index_x = i * (QR6_K*WARP_SIZE + 1) +  QR6_K*k;
+    const int index_y = j * WARP_SIZE             + (QR6_K*k) % WARP_SIZE;
     return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
 }
 
-template <int qk, int qr, int qi, bool need_sum, typename block_q_t,
+template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
               allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __global__ void mul_mat_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
@@ -3222,14 +3148,11 @@ static __global__ void mul_mat_q(
 
     const int & ncols_dst = ncols_y;
 
-    const int tid_x = threadIdx.x;
-    const int tid_y = threadIdx.y;
-
-    const int row_dst_0 = blockIdx.x*GGML_CUDA_MMQ_Y;
+    const int row_dst_0 = blockIdx.x*mmq_y;
     const int & row_x_0 = row_dst_0;
-    const int row_dst = row_dst_0 + tid_x;
+    const int row_dst = row_dst_0 + threadIdx.x;
 
-    const int col_dst_0 = blockIdx.y*WARP_SIZE;
+    const int col_dst_0 = blockIdx.y*mmq_x;
     const int & col_y_0 = col_dst_0;
 
     int   * tile_x_ql = nullptr;
@@ -3239,64 +3162,65 @@ static __global__ void mul_mat_q(
 
     allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
 
-    const int blocks_per_tile_y_col = qr*WARP_SIZE/QI8_1;
-
-    __shared__ int    tile_y_qs[(WARP_SIZE) * (qr*WARP_SIZE)];
-    __shared__ half2  tile_y_ds[(WARP_SIZE) * blocks_per_tile_y_col];
+    __shared__ int    tile_y_qs[mmq_x * WARP_SIZE];
+    __shared__ half2  tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
 
-    float sum[GGML_CUDA_MMQ_Y/WARP_SIZE][4] = {0.0f};
+    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {0.0f};
 
     for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
 
         load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
-                   tid_y, nrows_x-row_x_0-1, tid_x, blocks_per_row_x);
+                   threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);
 
+#pragma unroll
         for (int ir = 0; ir < qr; ++ir) {
-            const int kqs = ir*WARP_SIZE + tid_x;
+            const int kqs = ir*WARP_SIZE + threadIdx.x;
             const int kbxd = kqs / QI8_1;
 
-            for (int i = 0; i < WARP_SIZE; i += 8) {
-                const int col_y_eff = min(col_y_0 + tid_y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
+#pragma unroll
+            for (int i = 0; i < mmq_x; i += nwarps) {
+                const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
 
                 const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
 
-                tile_y_qs[(tid_y + i) * (qr*WARP_SIZE) + kqs] = get_int_from_int8_aligned(by0->qs, tid_x % QI8_1);
+                const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE;
+                tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
             }
-        }
 
-        for (int ids0 = 0; ids0 < WARP_SIZE; ids0 += 8 * (WARP_SIZE/blocks_per_tile_y_col)) {
-            const int ids = (ids0 + tid_y * (WARP_SIZE/blocks_per_tile_y_col) + tid_x / blocks_per_tile_y_col) % WARP_SIZE;
-            const int kby = tid_x % blocks_per_tile_y_col;
-            const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
-
-            // if the sum is not needed it's faster to transform the scale to f32 ahead of time
-            const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kby].ds;
-            half2       * dsi_dst = &tile_y_ds[ids * (qr*WARP_SIZE/QI8_1) + kby];
-            if (need_sum) {
-                *dsi_dst = *dsi_src;
-            } else {
-                float * dfi_dst = (float *) dsi_dst;
-                *dfi_dst = (*dsi_src).x;
+#pragma unroll
+            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
+                const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
+                const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
+                const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
+
+                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
+                const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
+                half2       * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
+                if (need_sum) {
+                    *dsi_dst = *dsi_src;
+                } else {
+                    float * dfi_dst = (float *) dsi_dst;
+                    *dfi_dst = (*dsi_src).x;
+                }
             }
-        }
 
-        __syncthreads();
+            __syncthreads();
 
-#if __CUDA_ARCH__ >= 700 // Unrolling the loop is slower on Pascal
+// #pragma unroll // unrolling this loop causes too much register pressure
+            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
 #pragma unroll
-#endif // __CUDA_ARCH__ >= 700
-        for (int k = 0; k < WARP_SIZE; k += vdr) {
+                for (int j = 0; j < mmq_x; j += nwarps) {
 #pragma unroll
-            for (int j = 0; j < WARP_SIZE; j += 8) {
-#pragma unroll
-                for (int i = 0; i < GGML_CUDA_MMQ_Y; i += WARP_SIZE) {
-                    sum[i/WARP_SIZE][j/8] += vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
-                                                     tid_x + i, tid_y + j, k);
+                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {
+                        sum[i/WARP_SIZE][j/nwarps] += vec_dot(
+                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
+                            threadIdx.x + i, threadIdx.y + j, k);
+                    }
                 }
             }
-        }
 
-        __syncthreads();
+            __syncthreads();
+        }
     }
 
 
@@ -3304,15 +3228,15 @@ static __global__ void mul_mat_q(
         return;
     }
 
-    for (int j = 0; j < WARP_SIZE; j += 8) {
-        const int col_dst = col_dst_0 + j + tid_y;
+    for (int j = 0; j < mmq_x; j += nwarps) {
+        const int col_dst = col_dst_0 + j + threadIdx.y;
 
         if (col_dst >= ncols_dst) {
             return;
         }
 
-        for (int i = 0; i < GGML_CUDA_MMQ_Y; i += WARP_SIZE) {
-            dst[col_dst*nrows_dst + row_dst + i] = sum[i/WARP_SIZE][j/8];
+        for (int i = 0; i < mmq_y; i += WARP_SIZE) {
+            dst[col_dst*nrows_dst + row_dst + i] = sum[i/WARP_SIZE][j/nwarps];
         }
     }
 }
@@ -4014,17 +3938,52 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
-    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
-    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<false>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    const int compute_capability = g_compute_capabilities[id];
+
+    if (compute_capability >= CC_TURING) {
+        const int mmq_x  = 64;
+        const int mmq_y  = 128;
+        const int nwarps = 4;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+                load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+                load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     } else {
-        mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<true>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        const int mmq_x  = 64;
+        const int mmq_y  = 64;
+        const int nwarps = 4;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+                load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+                load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     }
 }
 
@@ -4032,17 +3991,53 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
-    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
-    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<false>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    const int compute_capability = g_compute_capabilities[id];
+
+    if (compute_capability >= CC_TURING) {
+        const int mmq_x  = 64;
+        const int mmq_y  = 128;
+        const int nwarps = 4;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
+                load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
+                load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     } else {
-        mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<true>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        const int mmq_x  = 64;
+        const int mmq_y  = 64;
+        const int nwarps = 8;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
+                load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
+                load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
+
     }
 }
 
@@ -4050,17 +4045,52 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
-    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
-    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<false>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    const int compute_capability = g_compute_capabilities[id];
+
+    if (compute_capability >= CC_TURING) {
+        const int mmq_x  = 128;
+        const int mmq_y  = 64;
+        const int nwarps = 4;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
+                load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
+                load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     } else {
-        mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<true>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        const int mmq_x  = 64;
+        const int mmq_y  = 64;
+        const int nwarps = 8;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
+                load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
+                load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     }
 }
 
@@ -4068,17 +4098,52 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
-    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
-    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<false>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    const int compute_capability = g_compute_capabilities[id];
+
+    if (compute_capability >= CC_TURING) {
+        const int mmq_x  = 128;
+        const int mmq_y  = 64;
+        const int nwarps = 8;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
+                load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
+                load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     } else {
-        mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<true>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        const int mmq_x  = 64;
+        const int mmq_y  = 64;
+        const int nwarps = 8;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
+                load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
+                load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     }
 }
 
@@ -4086,17 +4151,52 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
-    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
-    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<false>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    const int compute_capability = g_compute_capabilities[id];
+
+    if (compute_capability >= CC_TURING) {
+        const int mmq_x  = 128;
+        const int mmq_y  = 64;
+        const int nwarps = 4;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
+                load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
+                load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     } else {
-        mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<true>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        const int mmq_x  = 64;
+        const int mmq_y  = 64;
+        const int nwarps = 8;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
+                load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
+                load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     }
 }
 
@@ -4104,17 +4204,52 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
-    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
-    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<false>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    const int compute_capability = g_compute_capabilities[id];
+
+    if (compute_capability >= CC_TURING) {
+        const int mmq_x  = 64;
+        const int mmq_y  = 128;
+        const int nwarps = 4;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
+                load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
+                load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     } else {
-        mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<true>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        const int mmq_x  = 64;
+        const int mmq_y  = 64;
+        const int nwarps = 8;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
+                load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
+                load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     }
 }
 
@@ -4122,17 +4257,52 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
-    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
-    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<false>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    const int compute_capability = g_compute_capabilities[id];
+
+    if (compute_capability >= CC_TURING) {
+        const int mmq_x  = 128;
+        const int mmq_y  = 128;
+        const int nwarps = 4;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
+                load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
+                load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     } else {
-        mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<true>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        const int mmq_x  = 64;
+        const int mmq_y  = 64;
+        const int nwarps = 8;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
+                load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
+                load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     }
 }
 
@@ -4140,17 +4310,52 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
-    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
-    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<false>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    const int compute_capability = g_compute_capabilities[id];
+
+    if (compute_capability >= CC_TURING) {
+        const int mmq_x  = 64;
+        const int mmq_y  = 128;
+        const int nwarps = 4;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
+                load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
+                load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     } else {
-        mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<true>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        const int mmq_x  = 32;
+        const int mmq_y  = 64;
+        const int nwarps = 8;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
+                load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
+                load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     }
 }
 
@@ -4158,17 +4363,52 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
-    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
-    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<false>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    const int compute_capability = g_compute_capabilities[id];
+
+    if (compute_capability >= CC_TURING) {
+        const int mmq_x  = 64;
+        const int mmq_y  = 128;
+        const int nwarps = 4;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
+                load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
+                load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     } else {
-        mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<true>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        const int mmq_x  = 64;
+        const int mmq_y  = 64;
+        const int nwarps = 8;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
+                load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
+                load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     }
 }
 
@@ -4176,17 +4416,52 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
-    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
-    if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<false>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    const int compute_capability = g_compute_capabilities[id];
+
+    if (compute_capability >= CC_TURING) {
+        const int mmq_x  = 64;
+        const int mmq_y  = 64;
+        const int nwarps = 4;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
+                load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
+                load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     } else {
-        mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<true>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
-            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        const int mmq_x  = 32;
+        const int mmq_y  = 64;
+        const int nwarps = 8;
+
+        const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+        const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+        const dim3 block_nums(block_num_x, block_num_y, 1);
+        const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+        if (nrows_x % mmq_y == 0) {
+            const bool need_check = false;
+            mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
+                load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        } else {
+            const bool need_check = true;
+            mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
+                load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+        }
     }
 }
 
@@ -4361,20 +4636,6 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
 }
 
 
-static void * g_scratch_buffer = nullptr;
-static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
-static size_t g_scratch_offset = 0;
-
-static int g_device_count = -1;
-static int g_main_device = 0;
-static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
-static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
-static bool g_mul_mat_q = false;
-
-static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
-
-static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
-
 void ggml_init_cublas() {
     static bool initialized = false;
 
@@ -4730,6 +4991,37 @@ inline void ggml_cuda_op_mul_mat_q(
     (void) i1;
 }
 
+static int64_t get_row_rounding(ggml_type type) {
+    int max_compute_capability = INT_MIN;
+    for (int id = 0; id < g_device_count; ++id) {
+        if (max_compute_capability < g_compute_capabilities[id]
+                && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
+            max_compute_capability = g_compute_capabilities[id];
+        }
+    }
+
+    switch(type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+            return max_compute_capability >= CC_TURING ? 128 : 64;
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+            return 64;
+        case GGML_TYPE_F16:
+            return 1;
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+            return max_compute_capability >= CC_TURING ? 128 : 64;
+        case GGML_TYPE_Q6_K:
+            return 64;
+        default:
+            GGML_ASSERT(false);
+    }
+}
+
 inline void ggml_cuda_op_mul_mat_vec(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -5130,14 +5422,16 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 
         int64_t row_low, row_high;
         if (split) {
+            const int64_t rounding = get_row_rounding(src0->type);
+
             row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
-            row_low -= row_low % GGML_CUDA_MMQ_Y;
+            row_low -= row_low % rounding;
 
             if (id == g_device_count - 1) {
                 row_high = nrows0;
             } else {
                 row_high = nrows0*g_tensor_split[id + 1];
-                row_high -= row_high % GGML_CUDA_MMQ_Y;
+                row_high -= row_high % rounding;
             }
         } else {
             row_low = 0;
@@ -5616,14 +5910,16 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
             row_low = 0;
             row_high = nrows;
         } else if (backend == GGML_BACKEND_GPU_SPLIT) {
+            const int64_t rounding = get_row_rounding(tensor->type);
+
             row_low = id == 0 ? 0 : nrows*g_tensor_split[id];
-            row_low -= row_low % GGML_CUDA_MMQ_Y;
+            row_low -= row_low % rounding;
 
             if (id == g_device_count - 1) {
                 row_high = nrows;
             } else {
                 row_high = nrows*g_tensor_split[id + 1];
-                row_high -= row_high % GGML_CUDA_MMQ_Y;
+                row_high -= row_high % rounding;
             }
         } else {
             GGML_ASSERT(false);