]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: stream-k decomposition for MMQ (#8018)
authorJohannes Gäßler <redacted>
Thu, 20 Jun 2024 12:39:21 +0000 (14:39 +0200)
committerGitHub <redacted>
Thu, 20 Jun 2024 12:39:21 +0000 (14:39 +0200)
* CUDA: stream-k decomposition for MMQ

* fix undefined memory reads for small matrices

ggml-cuda.cu
ggml-cuda/common.cuh
ggml-cuda/mmq.cu
ggml-cuda/mmq.cuh

index b8298ab205e6009b3e4313ffcb2ed2c2a6efd132..f914efd71266572fccba61c89654aa6738378b03 100644 (file)
@@ -635,7 +635,7 @@ static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> &
         }
 
         const int cc = ggml_cuda_info().devices[id].cc;
-        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
+        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
     }
     return row_rounding;
 }
index de7c2e4349ede23b6365e3abfee4f2b75d9dd1e8..5bd24ebe5fa79d83f075754e910f3372c030385a 100644 (file)
@@ -652,8 +652,8 @@ static int get_mmq_x_max_host(const int cc) {
 }
 
 // Round rows to this value for --split-mode row:
-static int get_mmq_y_host(const int cc, const int mmq_x) {
-    return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
+static int get_mmq_y_host(const int cc) {
+    return cc >= CC_VOLTA ? 128 : 64;
 }
 
 //////////////////////
index 1d6b9e6982b6e3cb78fc9ced9fbd9c01d47a7c75..6dbd85feff2fa2e44e74e6da985deddb8fcbe44b 100644 (file)
@@ -30,34 +30,34 @@ void ggml_cuda_op_mul_mat_q(
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            mul_mat_q_case<GGML_TYPE_Q4_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q4_1:
-            mul_mat_q_case<GGML_TYPE_Q4_1>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_0:
-            mul_mat_q_case<GGML_TYPE_Q5_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_1:
-            mul_mat_q_case<GGML_TYPE_Q5_1>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
             break;
         case GGML_TYPE_Q8_0:
-            mul_mat_q_case<GGML_TYPE_Q8_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q2_K:
-            mul_mat_q_case<GGML_TYPE_Q2_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q3_K:
-            mul_mat_q_case<GGML_TYPE_Q3_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q4_K:
-            mul_mat_q_case<GGML_TYPE_Q4_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_K:
-            mul_mat_q_case<GGML_TYPE_Q5_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q6_K:
-            mul_mat_q_case<GGML_TYPE_Q6_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
             break;
         default:
             GGML_ASSERT(false);
index 6d57974fb4e7c45c6997a9a511bfd5d8d6af8f48..e2d07c20253ae3529ea24c71a61ca095cb68edfc 100644 (file)
@@ -8,6 +8,7 @@
 #include <cstdint>
 
 #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+#define MMQ_NWARPS 8
 
 typedef void (*load_tiles_mmq_t)(
     const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
@@ -15,7 +16,7 @@ typedef void (*load_tiles_mmq_t)(
 typedef void (*vec_dot_mmq_t)(
     const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0);
-typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
 
 struct block_q8_1_mmq {
     half2  ds[4];
@@ -50,21 +51,17 @@ static constexpr __device__ int get_mmq_x_max_device() {
 
 // get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
 
+static constexpr __device__ int get_mmq_y_device() {
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-static constexpr __device__ int get_mmq_y_device(int mmq_x) {
-    return mmq_x >= 32 ? 128 : 64;
-}
+    return 128;
 #else
 #if __CUDA_ARCH__ >= CC_VOLTA
-static constexpr __device__ int get_mmq_y_device(int mmq_x) {
-    return mmq_x >= 32 ? 128 : 64;
-}
+    return 128;
 #else
-static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
     return 64;
-}
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+}
 
 #define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
 #define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
@@ -1734,30 +1731,34 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 }
 
 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
-static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+static __device__ __forceinline__ void mmq_write_back_dp4a(
+    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
+        const int j = j0 + threadIdx.y;
 
-        if (j >= ne1) {
+        if (j > j_max) {
             return;
         }
 
 #pragma unroll
         for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
+            const int i = i0 + threadIdx.x;
 
-            if (need_check && i >= ne0) {
+            if (need_check && i > i_max) {
                 continue;
             }
 
-            dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+            dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
         }
     }
 }
 
 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
-static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+static __device__ __forceinline__ void mmq_write_back_mma(
+    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
+
     typedef mma_int_C_I16J8 mma_C;
 
     const int i0 = threadIdx.y*mma_C::I;
@@ -1769,19 +1770,19 @@ static __device__ __forceinline__ void mmq_write_back_mma(const float * __restri
     for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
 #pragma unroll
         for (int l = 0; l < mma_C::ne; ++l) {
-            const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
+            const int j = j0 + mma_C::get_j(l);
 
-            if (j >= ne1) {
+            if (j > j_max) {
                 continue;
             }
 
-            const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
+            const int i = i0 + mma_C::get_i(l);
 
-            if (need_check && i >= ne0) {
+            if (need_check && i > i_max) {
                 continue;
             }
 
-            dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
+            dst[j*stride + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
         }
     }
 }
@@ -1896,32 +1897,16 @@ static bool mmq_need_sum(const ggml_type type_x) {
     return false;
 }
 
-template <ggml_type type, int mmq_x, int nwarps, bool need_check>
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#else
-#if __CUDA_ARCH__ >= CC_VOLTA
-    __launch_bounds__(WARP_SIZE*nwarps, 1)
-#else
-    __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // __CUDA_ARCH__ >= CC_VOLTA
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-static __global__ void mul_mat_q(
-    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
-    const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
-
-    // Skip unused template specializations for faster compilation:
-    if (mmq_x > get_mmq_x_max_device()) {
-        NO_DEVICE_CODE;
-        return;
-    }
+template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
+static __device__ void mul_mat_q_process_tile(
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+    const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
+    const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
 
     constexpr int              qk         = ggml_cuda_type_traits<type>::qk;
     constexpr int              qr         = ggml_cuda_type_traits<type>::qr;
     constexpr int              qi         = ggml_cuda_type_traits<type>::qi;
-    constexpr int              mmq_y      = get_mmq_y_device(mmq_x);
+    constexpr int              mmq_y      = get_mmq_y_device();
     constexpr int              vdr        = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
     constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
 
@@ -1941,20 +1926,18 @@ static __global__ void mul_mat_q(
     int   * tile_x_sc = (int   *) (tile_x_dm + txs.dm);
     int   * tile_y    = (int   *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
 
-    const int blocks_per_row_x = ne00 / qk;
-    const int blocks_per_warp = WARP_SIZE / qi;
+    constexpr int blocks_per_warp = WARP_SIZE / qi;
 
-    const int & ne1 = ne11;
-
-    const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;
+    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
 
-    const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
+    const int tile_x_max_i = ne01 - it*mmq_y - 1;
+    const int tile_y_max_j = ne11 - jt*mmq_x - 1;
 
-    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+    const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
 
-    for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
+    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) {
 
-        load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
+        load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
 
 #pragma unroll
         for (int kr = 0; kr < qr; ++kr) {
@@ -1977,7 +1960,176 @@ static __global__ void mul_mat_q(
         }
     }
 
-    write_back(sum, dst, ne0, ne1);
+    if (fixup) {
+        write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
+    } else {
+        write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
+    }
+}
+
+
+// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#else
+#if __CUDA_ARCH__ >= CC_VOLTA
+    __launch_bounds__(WARP_SIZE*nwarps, 1)
+#else
+    __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+static __global__ void mul_mat_q(
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+    const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
+
+    // Skip unused template specializations for faster compilation:
+    if (mmq_x > get_mmq_x_max_device()) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    constexpr int qk    = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi    = ggml_cuda_type_traits<type>::qi;
+    constexpr int mmq_y = get_mmq_y_device();
+
+    // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
+#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+    {
+        constexpr bool fixup = false;
+        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+                blockIdx.x, blockIdx.y, 0, ne00/qk);
+        return;
+    }
+#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+
+    const     int64_t blocks_per_ne00 = ne00 / qk;
+    constexpr int     blocks_per_warp = WARP_SIZE / qi;
+
+    const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
+    const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
+
+    // kbc == k block continuous, current index in continuous ijk space.
+    int64_t       kbc      = GGML_PAD((int64_t) blockIdx.x     *blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
+    const int64_t kbc_stop = GGML_PAD((int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
+
+    // kb0 == k index when doing the matrix multiplication for an output tile.
+    int kb0_start = kbc % blocks_per_ne00;
+    int kb0_stop  = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
+    while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
+        const int jt =  kbc /    (blocks_per_ne00*nty);                    // j index of current tile.
+        const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.
+
+        constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
+        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+             it, jt, kb0_start, kb0_stop);
+
+        kbc += blocks_per_ne00;
+        kbc -= kbc % blocks_per_ne00;
+
+        kb0_start = 0;
+        kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);
+    }
+
+    if (kbc >= kbc_stop) {
+        return;
+    }
+
+    const int jt =  kbc /    (blocks_per_ne00*nty);
+    const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+    constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks.
+    mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+        (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+            it, jt, kb0_start, kb0_stop);
+}
+
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+static __global__ void mul_mat_q_stream_k_fixup(
+    float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {
+
+    constexpr int     mmq_y           = get_mmq_y_device();
+    constexpr int     qk              = ggml_cuda_type_traits<type>::qk;
+    constexpr int     qi              = ggml_cuda_type_traits<type>::qi;
+    constexpr int     blocks_per_warp = WARP_SIZE / qi;
+    const     int64_t blocks_per_ne00 = ne00 / qk;
+
+    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+
+    const int ntx = (ne11 + mmq_x - 1) / mmq_x;
+    const int nty = (ne01 + mmq_y - 1) / mmq_y;
+
+    bool any_fixup = false;
+
+    const int bidx_start = (blockIdx.y*nty + blockIdx.x)     * block_num_mmq / (gridDim.y*gridDim.x);
+    const int bidx_stop  = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;
+
+    for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
+        const int64_t kbc      = GGML_PAD((int64_t) bidx     *blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
+        const int64_t kbc_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
+
+        // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
+        if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
+            continue;
+        }
+
+        const int jt =  kbc_stop /    (blocks_per_ne00*nty);
+        const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+        // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
+        if (it != blockIdx.x || jt != blockIdx.y) {
+            continue;
+        }
+
+        any_fixup = true;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
+            }
+        }
+    }
+
+    if (!any_fixup) {
+        return;
+    }
+
+    dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;
+
+    const int i_max = ne01 - blockIdx.x*mmq_y - 1;
+    const int j_max = ne11 - blockIdx.y*mmq_x - 1;
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j > j_max) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            if (need_check && i > i_max) {
+                continue;
+            }
+
+            dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+        }
+    }
 }
 
 struct mmq_args {
@@ -1987,124 +2139,151 @@ struct mmq_args {
     int64_t ne0;
 };
 
-constexpr int mmq_get_nwarps(int mmq_x) {
-    return mmq_x >= 32 ? 8 : 4;
-}
-
 static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
     const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
-    const int nwarps = mmq_get_nwarps(mmq_x);
 
     const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
     const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
-    return shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
+    return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
 }
 
-template <ggml_type type, int mmq_x, int nwarps>
-static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
+template <ggml_type type, int mmq_x>
+static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     const int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
-    const int mmq_y = get_mmq_y_host(cc, mmq_x);
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    const int mmq_y = get_mmq_y_host(cc);
 
-    const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y;
-    const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
 
     const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);
 
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
     static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
     if (!shmem_limit_raised[id]) {
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         shmem_limit_raised[id] = true;
     }
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 
+    const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
+    const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
+    const dim3 block_nums_xy_tiling(nty, ntx, 1);
+
+    const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
+    if (!use_stream_k) {
+        if (args.ne01 % mmq_y == 0) {
+            constexpr bool need_check = false;
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        } else {
+            constexpr bool need_check = true;
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        }
+        return;
+    }
+
+    const dim3 block_nums_mmq(nsm, 1, 1);
+
+    ggml_cuda_pool & pool = ctx.pool();
+    ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);
+
     if (args.ne01 % mmq_y == 0) {
-        const bool need_check = false;
-        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
-            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        constexpr bool need_check = false;
+
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
     } else {
-        const bool need_check = true;
-        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
-            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        constexpr bool need_check = true;
+
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
     }
 }
 
 template <ggml_type type>
-void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
+void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     const int id    = ggml_cuda_get_device();
     const int nsm   = ggml_cuda_info().devices[id].nsm;
     const int cc    = ggml_cuda_info().devices[id].cc;
     const int smpbo = ggml_cuda_info().devices[id].smpbo;
 
     const int mmq_x_max = get_mmq_x_max_host(cc);
-    const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
+    const int mmq_y = get_mmq_y_host(cc);
     const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
+    const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
 
     int mmq_x_best  = 0;
-    int nwaves_best = INT_MAX;
+    int nparts_best = INT_MAX;
+
+    for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
+        const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
+        const int nwaves_xy_tiling = ntiles_x*block_num_y;
 
-    for (int mmq_x = 8; mmq_x <= mmq_x_max && nwaves_best > 1; mmq_x += 8) {
-        const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
-        const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;
+        const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
 
-        if (nwaves < nwaves_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
+        if (nparts < nparts_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
             mmq_x_best  = mmq_x;
-            nwaves_best = nwaves;
+            nparts_best = nparts;
         }
     }
 
     switch (mmq_x_best) {
         case   8:
-            launch_mul_mat_q<type,   8, mmq_get_nwarps(  8)>(args, stream);
+            launch_mul_mat_q<type,   8>(ctx, args, stream);
             break;
         case  16:
-            launch_mul_mat_q<type,  16, mmq_get_nwarps( 16)>(args, stream);
+            launch_mul_mat_q<type,  16>(ctx, args, stream);
             break;
         case  24:
-            launch_mul_mat_q<type,  24, mmq_get_nwarps( 24)>(args, stream);
+            launch_mul_mat_q<type,  24>(ctx, args, stream);
             break;
         case  32:
-            launch_mul_mat_q<type,  32, mmq_get_nwarps( 32)>(args, stream);
+            launch_mul_mat_q<type,  32>(ctx, args, stream);
             break;
         case  40:
-            launch_mul_mat_q<type,  40, mmq_get_nwarps( 40)>(args, stream);
+            launch_mul_mat_q<type,  40>(ctx, args, stream);
             break;
         case  48:
-            launch_mul_mat_q<type,  48, mmq_get_nwarps( 48)>(args, stream);
+            launch_mul_mat_q<type,  48>(ctx, args, stream);
             break;
         case  56:
-            launch_mul_mat_q<type,  56, mmq_get_nwarps( 56)>(args, stream);
+            launch_mul_mat_q<type,  56>(ctx, args, stream);
             break;
         case  64:
-            launch_mul_mat_q<type,  64, mmq_get_nwarps( 64)>(args, stream);
+            launch_mul_mat_q<type,  64>(ctx, args, stream);
             break;
         case  72:
-            launch_mul_mat_q<type,  72, mmq_get_nwarps( 72)>(args, stream);
+            launch_mul_mat_q<type,  72>(ctx, args, stream);
             break;
         case  80:
-            launch_mul_mat_q<type,  80, mmq_get_nwarps( 80)>(args, stream);
+            launch_mul_mat_q<type,  80>(ctx, args, stream);
             break;
         case  88:
-            launch_mul_mat_q<type,  88, mmq_get_nwarps( 88)>(args, stream);
+            launch_mul_mat_q<type,  88>(ctx, args, stream);
             break;
         case  96:
-            launch_mul_mat_q<type,  96, mmq_get_nwarps( 96)>(args, stream);
+            launch_mul_mat_q<type,  96>(ctx, args, stream);
             break;
         case 104:
-            launch_mul_mat_q<type, 104, mmq_get_nwarps(104)>(args, stream);
+            launch_mul_mat_q<type, 104>(ctx, args, stream);
             break;
         case 112:
-            launch_mul_mat_q<type, 112, mmq_get_nwarps(112)>(args, stream);
+            launch_mul_mat_q<type, 112>(ctx, args, stream);
             break;
         case 120:
-            launch_mul_mat_q<type, 120, mmq_get_nwarps(120)>(args, stream);
+            launch_mul_mat_q<type, 120>(ctx, args, stream);
             break;
         case 128:
-            launch_mul_mat_q<type, 128, mmq_get_nwarps(128)>(args, stream);
+            launch_mul_mat_q<type, 128>(ctx, args, stream);
             break;
         default:
             fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
@@ -2114,7 +2293,7 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
 }
 
 #define DECL_MMQ_CASE(type)                                                        \
-    template void mul_mat_q_case<type>(const mmq_args & args, cudaStream_t stream) \
+    template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
 
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);