]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: revise q8_1 data layout for mul_mat_q (llama/7824)
authorJohannes Gäßler <redacted>
Sun, 9 Jun 2024 07:42:25 +0000 (09:42 +0200)
committerGeorgi Gerganov <redacted>
Sun, 16 Jun 2024 15:19:48 +0000 (18:19 +0300)
ggml-cuda.cu
ggml-cuda/mmq.cu
ggml-cuda/mmq.cuh
ggml-cuda/quantize.cu
ggml-cuda/quantize.cuh

index dad8a9e2dafe78c3c3e44feaa371c6f27e548566..af10f21a0a92a0e0a11fafe6b98162830e77b24e 100644 (file)
@@ -1347,10 +1347,30 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
     GGML_UNUSED(main_device);
 }
 
+static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
+    void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
+
+#if !defined(GGML_USE_HIPBLAS)
+    // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
+    cudaMemcpy3DPeerParms p = {};
+    p.dstDevice = dstDevice;
+    p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height);
+    p.srcDevice = srcDevice;
+    p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height);
+    p.extent = make_cudaExtent(width, height, 1);
+    return cudaMemcpy3DPeerAsync(&p, stream);
+#else
+    // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
+    GGML_UNUSED(dstDevice);
+    GGML_UNUSED(srcDevice);
+    return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
+#endif // !defined(GGML_USE_HIPBLAS)
+}
+
 static void ggml_cuda_op_mul_mat(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
-    const bool convert_src1_to_q8_1) {
+    quantize_cuda_t quantize_src1) {
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -1407,7 +1427,9 @@ static void ggml_cuda_op_mul_mat(
     }
 
     struct dev_data {
-        ggml_cuda_pool_alloc<char>  src0_dd_alloc;
+        int cc;
+
+        ggml_cuda_pool_alloc<char>   src0_dd_alloc;
         ggml_cuda_pool_alloc<float> src1_ddf_alloc;
         ggml_cuda_pool_alloc<char>  src1_ddq_alloc;
         ggml_cuda_pool_alloc<float>   dst_dd_alloc;
@@ -1426,6 +1448,8 @@ static void ggml_cuda_op_mul_mat(
     int used_devices = 0;
 
     for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        dev[id].cc = ggml_cuda_info().devices[id].cc;
+
         // by default, use all rows
         dev[id].row_low  = 0;
         dev[id].row_high = ne01;
@@ -1476,11 +1500,15 @@ static void ggml_cuda_op_mul_mat(
             dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1));
         }
 
-        if (convert_src1_to_q8_1) {
-            dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
+        if (quantize_src1) {
+            size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
+            if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+                src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq);
+            }
+            dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
 
             if (src1_on_device && src1_is_contiguous) {
-                quantize_row_q8_1_cuda(dev[id].src1_ddf, dev[id].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
+                quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
                 CUDA_CHECK(cudaGetLastError());
             }
         }
@@ -1526,7 +1554,12 @@ static void ggml_cuda_op_mul_mat(
                 const int64_t i03 = i0 / ne12;
                 const int64_t i02 = i0 % ne12;
 
-                const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
+                size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs;
+                if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+                    src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq);
+                } else {
+                    src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs;
+                }
 
                 // for split tensors the data begins at i0 == i0_offset_low
                 char  *  src0_dd_i =  dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
@@ -1543,10 +1576,17 @@ static void ggml_cuda_op_mul_mat(
                 // copy src0, src1 to device if necessary
                 if (src1_is_contiguous) {
                     if (id != ctx.device) {
-                        if (convert_src1_to_q8_1) {
+                        if (quantize_src1) {
                             char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
-                            CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddq_i, id, src1_ddq_i_source, ctx.device,
-                                                            src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
+                            if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+                                const size_t pitch = ne11*sizeof(block_q8_1_mmq);
+                                const size_t width = src1_ncols*sizeof(block_q8_1_mmq);
+                                const size_t height = src1_padded_col_size/(4*QK8_1);
+                                CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream));
+                            } else {
+                                CUDA_CHECK(cudaMemcpyPeerAsync(
+                                    src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
+                            }
                         } else {
                             float * src1_ddf_i_source = (float *) src1->data;
                             src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
@@ -1561,8 +1601,8 @@ static void ggml_cuda_op_mul_mat(
                     GGML_ASSERT(false);
                 }
 
-                if (convert_src1_to_q8_1 && !src1_is_contiguous) {
-                    quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
+                if (quantize_src1 && !src1_is_contiguous) {
+                    quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
                     CUDA_CHECK(cudaGetLastError());
                 }
 
@@ -1587,22 +1627,8 @@ static void ggml_cuda_op_mul_mat(
                         float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                         GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
                         dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
-#if !defined(GGML_USE_HIPBLAS)
-                        // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
-                        cudaMemcpy3DPeerParms p = {};
-                        p.dstDevice = ctx.device;
-                        p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), row_diff, src1_ncols);
-                        p.srcDevice = id;
-                        p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols);
-                        p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1);
-                        CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream));
-#else
-                        // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
-                        CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float),
-                                                        dst_dd_i, row_diff*sizeof(float),
-                                                        row_diff*sizeof(float), src1_ncols,
-                                                        cudaMemcpyDeviceToDevice, stream));
-#endif
+                        CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(
+                            dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream));
                     } else {
                         float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                         GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
@@ -1941,13 +1967,13 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         // KQ + KQV multi-batch
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
     } else if (use_mul_mat_vec_q) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
     } else {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
     }
 }
 
index 58799e4caf6f8b6493134de36a0968634725eec3..1d6b9e6982b6e3cb78fc9ced9fbd9c01d47a7c75 100644 (file)
@@ -11,6 +11,7 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t nb01 = src0->nb[1];
 
     const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
     GGML_ASSERT(ne10 % QK8_1 == 0);
 
     const int64_t ne0 = dst->ne[0];
@@ -25,7 +26,7 @@ void ggml_cuda_op_mul_mat_q(
     // nrows_dst == nrows of the matrix that the kernel writes into
     const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
 
-    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, nrows_dst};
+    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
index 6744cce6d785f7a951770b2780e47083c7c44dcf..3ccae8a0c36fa2e6ee164a59d43a6d91e5bdb1c3 100644 (file)
@@ -1,15 +1,26 @@
+#pragma once
+
 #include "common.cuh"
 #include "vecdotq.cuh"
 
 #include <climits>
 #include <cstdint>
 
+#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+
 typedef void (*load_tiles_mmq_t)(
     const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
 typedef void (*vec_dot_mmq_t)(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, float * __restrict__ sum, const int & k0);
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0);
+
+struct block_q8_1_mmq {
+    half2  ds[4];
+    int8_t qs[4*QK8_1];
+};
+static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
+static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1),      "Unexpected block_q8_1_mmq size");
 
 struct tile_x_sizes {
     int ql;
@@ -132,10 +143,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
 
+    const float * x_dmf = (const float *) x_dm;
+    const int   * y_qs  = (const int   *) y + 4;
+    const half2 * y_ds  = (const half2 *) y;
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -145,19 +160,18 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
             const int i = i0 + threadIdx.x;
 
             const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
-            const float * x_dmf = (const float *) x_dm;
 
             int u[2*VDR_Q4_0_Q8_1_MMQ];
 
 #pragma unroll
             for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
-                u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
-                u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
+                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l)         % WARP_SIZE];
+                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE];
             }
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
-                (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
-                y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+                (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
+                y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
     }
 }
@@ -203,10 +217,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
 
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -221,13 +238,13 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
 
 #pragma unroll
             for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
-                u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
-                u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
+                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l)         % WARP_SIZE];
+                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_1) % WARP_SIZE];
             }
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
-                (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
-                y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+                (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
+                y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
     }
 }
@@ -293,10 +310,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
 
+    const float * x_dmf = (const float *) x_dm;
+    const int   * y_qs  = (const int   *) y + 4;
+    const float * y_df  = (const float *) y;
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -306,20 +327,18 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
             const int i = i0 + threadIdx.x;
 
             const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
-            const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0;
-            const float * x_dmf = (const float *) x_dm;
-            const float * y_df  = (const float *) y_ds;
+            const int index_bx = i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0;
 
             int u[2*VDR_Q5_0_Q8_1_MMQ];
 
 #pragma unroll
             for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
-                u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
-                u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
+                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l)         % WARP_SIZE];
+                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_0) % WARP_SIZE];
             }
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
-                (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+                (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
     }
 }
@@ -383,10 +402,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
 
+    const int   * y_qs  = (const int   *) y + 4;
+    const half2 * y_ds  = (const half2 *) y;
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -396,18 +418,18 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
             const int i = i0 + threadIdx.x;
 
             const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
-            const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k0/QI5_1;
+            const int index_bx = i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1;
 
             int u[2*VDR_Q5_1_Q8_1_MMQ];
 
 #pragma unroll
             for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
-                u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
-                u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
+                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l)         % WARP_SIZE];
+                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_1) % WARP_SIZE];
             }
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
-                (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+                (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
         }
     }
 }
@@ -455,10 +477,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
 
+    const float * x_dmf = (const float *) x_dm;
+    const int   * y_qs  = (const int   *) y + 4;
+    const float * y_df  = (const float *) y;
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -467,12 +493,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
         for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            const float * x_dmf = (const float *) x_dm;
-            const float * y_df  = (const float *) y_ds;
-
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
-                (&x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[j * WARP_SIZE + k0], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
-                y_df[j * (WARP_SIZE/QI8_1) + k0/QI8_1]);
+                (&x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
+                y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
         }
     }
 }
@@ -531,10 +554,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
     GGML_UNUSED(x_qh);
 
+    const int   * y_qs  = (const int   *) y + 4;
+    const float * y_df  = (const float *) y;
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -545,11 +571,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
 
             const int kbx = k0 / QI2_K;
             const int ky  = (k0 % QI2_K) * QR2_K;
-            const float * y_df = (const float *) y_ds;
 
             int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
 
-            const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
+            const int kqsx = i*(WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
             const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
 
 #pragma unroll
@@ -557,11 +582,11 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
                 v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
             }
 
-            const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
+            const uint8_t * scales = ((const uint8_t *) &x_sc[i*(WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
 
-            const int index_y = j * WARP_SIZE + (QR2_K*k0) % WARP_SIZE;
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
-                v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
+                v, &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE], scales,
+                x_dm[i*(WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
         }
     }
 }
@@ -646,7 +671,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
+
+    const float * x_dmf = (const float *) x_dm;
+    const int   * y_qs  = (const int   *) y + 4;
+    const float * y_df  = (const float *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -658,8 +687,6 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
 
             const int kbx  = k0 / QI3_K;
             const int ky  = (k0 % QI3_K) * QR3_K;
-            const float * x_dmf = (const float *) x_dm;
-            const float * y_df  = (const float *) y_ds;
 
             const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
 
@@ -667,19 +694,19 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
 
 #pragma unroll
             for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
-                const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
+                const int kqsx = i*(WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
                 const int shift = 2 * ((ky % 32) / 8);
                 const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
 
-                const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
+                const int vh = x_qh[i*(WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
                 const int vlh = (vh << 2) & 0x04040404;
 
                 v[l] = __vsubss4(vll, vlh);
             }
 
-            const int index_y = j * WARP_SIZE + (k0*QR3_K) % WARP_SIZE;
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
-                v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
+                v, &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
+                x_dmf[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
         }
     }
 }
@@ -746,10 +773,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
     GGML_UNUSED(x_qh);
 
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -760,9 +790,9 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
 
             const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
 
-            const int index_y = j * WARP_SIZE + (QR4_K*k0) % WARP_SIZE;
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
-                &x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
+                &x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
+                x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
         }
     }
 }
@@ -842,10 +872,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
     GGML_UNUSED(x_qh);
 
+    const int   * y_qs  = (const int   *) y + 4;
+    const half2 * y_ds  = (const half2 *) y;
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -856,10 +889,9 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
 
             const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
 
-            const int index_x = i * (QR5_K*WARP_SIZE + 1) +  QR5_K*k0;
-            const int index_y = j * WARP_SIZE             + (QR5_K*k0) % WARP_SIZE;
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
-                &x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
+                &x_ql[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
+                x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
         }
     }
 }
@@ -932,10 +964,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
-    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
     GGML_UNUSED(x_qh);
 
+    const float * x_dmf = (const float *) x_dm;
+    const int   * y_qs  = (const int   *) y + 4;
+    const float * y_df  = (const float *) y;
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -944,15 +980,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
         for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            const float * x_dmf = (const float *) x_dm;
-            const float * y_df  = (const float *) y_ds;
-
             const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
 
-            const int index_x = i * (QR6_K*WARP_SIZE + 1) +  QR6_K*k0;
-            const int index_y = j * WARP_SIZE             + (QR6_K*k0) % WARP_SIZE;
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
-                &x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
+                &x_ql[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
+                x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
         }
     }
 }
@@ -964,7 +996,6 @@ struct mmq_type_traits;
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
-    static constexpr bool             need_sum   = true;
     static constexpr int              vdr        = VDR_Q4_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -972,7 +1003,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
-    static constexpr bool             need_sum   = true;
     static constexpr int              vdr        = VDR_Q4_1_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -980,7 +1010,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
-    static constexpr bool             need_sum   = false;
     static constexpr int              vdr        = VDR_Q5_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -988,7 +1017,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
-    static constexpr bool             need_sum   = true;
     static constexpr int              vdr        = VDR_Q5_1_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -996,7 +1024,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
-    static constexpr bool             need_sum   = false;
     static constexpr int              vdr        = VDR_Q8_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -1004,7 +1031,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
-    static constexpr bool             need_sum   = false;
     static constexpr int              vdr        = VDR_Q2_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -1012,7 +1038,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
-    static constexpr bool             need_sum   = false;
     static constexpr int              vdr        = VDR_Q3_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -1020,7 +1045,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
-    static constexpr bool             need_sum   = true;
     static constexpr int              vdr        = VDR_Q4_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -1028,7 +1052,6 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
-    static constexpr bool             need_sum   = true;
     static constexpr int              vdr        = VDR_Q5_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
@@ -1036,12 +1059,36 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
-    static constexpr bool             need_sum   = false;
     static constexpr int              vdr        = VDR_Q6_K_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
 };
 
+static int mmq_need_sum(const ggml_type type_x) {
+    switch (type_x) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+            return true;
+        case GGML_TYPE_Q5_0:
+            return false;
+        case GGML_TYPE_Q5_1:
+            return true;
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+            return false;
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+            return true;
+        case GGML_TYPE_Q6_K:
+            return false;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+    return false;
+}
+
 template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #if defined(RDNA3) || defined(RDNA2)
@@ -1056,7 +1103,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 static __global__ void mul_mat_q(
     const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
-    const int ne00, const int ne01, const int stride00, const int ne10, const int ne11, const int ne0) {
+    const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
 
     // Skip unused template specializations for faster compilation:
     if (mmq_x > get_mmq_x_max_device()) {
@@ -1068,7 +1115,6 @@ static __global__ void mul_mat_q(
     constexpr int              qr         = ggml_cuda_type_traits<type>::qr;
     constexpr int              qi         = ggml_cuda_type_traits<type>::qi;
     constexpr int              mmq_y      = get_mmq_y_device(mmq_x);
-    constexpr bool             need_sum   = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::need_sum;
     constexpr int              vdr        = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
     constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
@@ -1080,62 +1126,38 @@ static __global__ void mul_mat_q(
     half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql);
     int   * tile_x_qh = (int   *) (tile_x_dm + txs.dm);
     int   * tile_x_sc = (int   *) (tile_x_qh + txs.qh);
-    int   * tile_y_qs = (int   *) (tile_x_sc + txs.sc);          // [mmq_x * WARP_SIZE]
-    half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1];
-
-    const block_q8_1 * y = (const block_q8_1 *) yc;
+    int   * tile_y    = (int   *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
 
     const int blocks_per_row_x = ne00 / qk;
-    const int blocks_per_col_y = ne10 / QK8_1;
     const int blocks_per_warp = WARP_SIZE / qi;
 
     const int & ne1 = ne11;
 
     const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;
 
+    const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
+
     float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};
 
     for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
 
-        load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride00*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride00);
+        load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
 
 #pragma unroll
         for (int kr = 0; kr < qr; ++kr) {
-            const int kqs = kr*WARP_SIZE + threadIdx.x;
-            const int kbxd = kqs / QI8_1;
-
+            const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int));
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_x; i0 += nwarps) {
-                const int i = min(blockIdx.y*mmq_x + threadIdx.y + i0, ne11-1); // to prevent out-of-bounds memory accesses
-
-                const block_q8_1 * by0 = &y[i*blocks_per_col_y + kb0 * (qk/QK8_1) + kbxd];
+            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
+                int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
 
-                const int index_y = (i0 + threadIdx.y) * WARP_SIZE + kqs % WARP_SIZE;
-                tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
-            }
-
-#pragma unroll
-            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
-                const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
-                const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
-                const int i_y_eff = min(blockIdx.y*mmq_x + ids, ne11-1);
-
-                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
-                const half2 * dsi_src = &y[i_y_eff*blocks_per_col_y + kb0 * (qk/QK8_1) + kr*(WARP_SIZE/QI8_1) + kby].ds;
-                half2       * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
-                if (need_sum) {
-                    *dsi_dst = *dsi_src;
-                } else {
-                    float * dfi_dst = (float *) dsi_dst;
-                    *dfi_dst = __low2float(*dsi_src);
-                }
+                tile_y[l] = by0[l];
             }
 
             __syncthreads();
 
 // #pragma unroll // unrolling this loop causes too much register pressure
             for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
-                vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, sum, k0);
+                vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y, sum, k0);
             }
 
             __syncthreads();
@@ -1165,8 +1187,8 @@ static __global__ void mul_mat_q(
 
 struct mmq_args {
     const char * x; const char * y; float * dst;
-    int64_t ne00; int64_t ne01; int64_t stride00;
-    int64_t ne10; int64_t ne11;
+    int64_t ne00; int64_t ne01; int64_t stride01;
+    int64_t ne10; int64_t ne11; int64_t stride11;
     int64_t ne0;
 };
 
@@ -1184,7 +1206,7 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
     const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
     const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int);
     const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
-    const int shmem = shmem_x + shmem_y;
+    const int shmem = shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
 
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
     static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1198,11 +1220,11 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
     if (args.ne01 % mmq_y == 0) {
         const bool need_check = false;
         mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
-            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0);
+            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
     } else {
         const bool need_check = true;
         mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
-            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0);
+            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
     }
 }
 
index 7578c4b6c7cab7553914f3c84b45ffcdff87664b..b4678682238d31c1007071b30cd1f79f21dba398 100644 (file)
@@ -1,22 +1,23 @@
 #include "quantize.cuh"
+#include <cstdint>
 
-static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx_padded) {
-    const int64_t ix = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
+    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (ix >= kx_padded) {
+    if (ix0 >= kx0_padded) {
         return;
     }
 
-    const int64_t iy = (int64_t)blockDim.y*blockIdx.y + threadIdx.y;
+    const int64_t ix1 = blockIdx.y;
 
-    const int64_t i_padded = (int64_t)iy*kx_padded + ix;
+    const int64_t i_padded = ix1*kx0_padded + ix0;
 
     block_q8_1 * y = (block_q8_1 *) vy;
 
     const int64_t ib = i_padded / QK8_1; // block index
     const int64_t iqs = i_padded % QK8_1; // quant index
 
-    const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
+    const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
     float amax = fabsf(xi);
     float sum = xi;
 
@@ -36,10 +37,76 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }
 
-void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream) {
-    const int64_t block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-    const dim3 num_blocks(block_num_x, ky, 1);
+template <bool need_sum>
+static __global__ void quantize_mmq_q8_1(
+    const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
+
+    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (ix0 >= kx0_padded) {
+        return;
+    }
+
+    const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
+
+    block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
+
+    const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel
+    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;              // block index in channel
+    const int64_t iqs = ix0 % (4*QK8_1);                                       // quant index in block
+
+    const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
+    float amax = fabsf(xi);
+
+    amax = warp_reduce_max(amax);
+
+    float sum;
+    if (need_sum) {
+        sum = warp_reduce_sum(xi);
+    }
+
+    const float d = amax / 127;
+    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+    y[ib].qs[iqs] = q;
+
+    if (iqs % QK8_1 != 0) {
+        return;
+    }
+
+    if (need_sum) {
+        y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
+    } else {
+        ((float *) y[ib].ds)[iqs/QK8_1] = d;
+    }
+}
+
+void quantize_row_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+    GGML_ASSERT(kx0_padded % QK8_1 == 0);
+
+    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, kx1*channels, 1);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
+
+    GGML_UNUSED(type_x);
 }
 
+void quantize_mmq_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+    GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+
+    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, kx1, channels);
+    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
+    if (mmq_need_sum(type_x)) {
+        quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+    } else {
+        quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+    }
+}
index b37a4752f2d24e2de0b1241511f8085721ffc5f2..486c9360a46fdc782614e052e2524cf12a27eb7a 100644 (file)
@@ -1,5 +1,20 @@
+#pragma once
+
 #include "common.cuh"
+#include "mmq.cuh"
+
+#include <cstdint>
 
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
 
-void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream);
+typedef void (*quantize_cuda_t)(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+    const ggml_type type_x, cudaStream_t stream);
+
+void quantize_row_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+    const ggml_type type_x, cudaStream_t stream);
+
+void quantize_mmq_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+    const ggml_type type_x, cudaStream_t stream);