]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: batched+noncont MMQ, refactor bs>1 MoE code (#13199)
authorJohannes Gäßler <redacted>
Wed, 30 Apr 2025 21:12:59 +0000 (23:12 +0200)
committerGitHub <redacted>
Wed, 30 Apr 2025 21:12:59 +0000 (23:12 +0200)
ggml/src/ggml-cuda/getrows.cu
ggml/src/ggml-cuda/getrows.cuh
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmq.cuh
ggml/src/ggml-cuda/mmvq.cu
ggml/src/ggml-cuda/quantize.cu
ggml/src/ggml-cuda/quantize.cuh
tests/test-backend-ops.cpp

index 4cef53a98cfd6acf87175eba687eb1ea16e89e2b..ea8bf691609966ced806b0682631e6ad486eedbf 100644 (file)
@@ -33,8 +33,8 @@ static __global__ void k_get_rows(
     dfloat2 v;
     dequantize_kernel(src0_row, ib, iqs, v);
 
-    dst_row[iybs + iqs + 0]        = v.x;
-    dst_row[iybs + iqs + y_offset] = v.y;
+    dst_row[iybs + iqs + 0]        = float(v.x);
+    dst_row[iybs + iqs + y_offset] = float(v.y);
 }
 
 template<typename src0_t, typename dst_t>
@@ -60,7 +60,7 @@ static __global__ void k_get_rows_float(
     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
     const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
 
-    dst_row[i00] = src0_row[i00];
+    dst_row[i00] = float(src0_row[i00]);
 }
 
 template<typename grad_t, typename dst_t>
@@ -86,122 +86,161 @@ static __global__ void k_get_rows_back_float(
     dst[dst_row*ncols + col] = sum;
 }
 
-template<int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_cuda(
-        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-        const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
+template<int qk, int qr, dequantize_kernel_t dq, typename dst_t>
+static void get_rows_cuda_q(
+        const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
+        const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
     const dim3 block_nums(block_num_x, ne10, ne11*ne12);
 
     // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
+    // const size_t s0 = nb0 / sizeof(dst_t);
+    const size_t s1 = nb1 / sizeof(dst_t);
+    const size_t s2 = nb2 / sizeof(dst_t);
+    const size_t s3 = nb3 / sizeof(dst_t);
 
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
+    const size_t s10 = nb10 / sizeof(int32_t);
+    const size_t s11 = nb11 / sizeof(int32_t);
+    const size_t s12 = nb12 / sizeof(int32_t);
+    // const size_t s13 = nb13 / sizeof(int32_t);
 
     GGML_ASSERT(ne00 % 2 == 0);
 
     k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
-        src0_dd, src1_dd, dst_dd,
+        src0_d, src1_d, dst_d,
         ne00, /*ne01, ne02, ne03,*/
         /*ne10, ne11,*/ ne12, /*ne13,*/
         /* s0,*/ s1, s2, s3,
         /* nb00,*/ nb01, nb02, nb03,
         s10, s11, s12/*, s13*/);
-
-    GGML_UNUSED(dst);
 }
 
-template<typename src0_t>
+template<typename src0_t, typename dst_t>
 static void get_rows_cuda_float(
-        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-        const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(ne13 == 1);
-
+        const src0_t * src0_d, const int32_t * src1_d, dst_t * dst_d,
+        const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
     const dim3 block_nums(block_num_x, ne10, ne11*ne12);
 
     // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
+    // const size_t s0 = nb0 / sizeof(dst_t);
+    const size_t s1 = nb1 / sizeof(dst_t);
+    const size_t s2 = nb2 / sizeof(dst_t);
+    const size_t s3 = nb3 / sizeof(dst_t);
 
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
+    const size_t s10 = nb10 / sizeof(int32_t);
+    const size_t s11 = nb11 / sizeof(int32_t);
+    const size_t s12 = nb12 / sizeof(int32_t);
+    // const size_t s13 = nb13 / sizeof(int32_t);
 
     k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
-        src0_dd, src1_dd, dst_dd,
+        src0_d, src1_d, dst_d,
         ne00, /*ne01, ne02, ne03,*/
         /*ne10, ne11,*/ ne12, /*ne13,*/
         /* s0,*/ s1, s2, s3,
         /* nb00,*/ nb01, nb02, nb03,
         s10, s11, s12/*, s13*/);
-
-    GGML_UNUSED(dst);
 }
 
-void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-
-    const void    * src0_d = (const void    *) src0->data;
-    const int32_t * src1_d = (const int32_t *) src1->data;
-    float         * dst_d  = (float         *) dst->data;
-
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
-
-    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(dst->nb[0]  == ggml_type_size(dst->type));
-
-    switch (src0->type) {
+template <typename dst_t>
+static void ggml_cuda_get_rows_switch_src0_type(
+        const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
+        const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        cudaStream_t stream) {
+    switch (src0_type) {
         case GGML_TYPE_F16:
-            get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_float((const half *) src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_F32:
-            get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        case GGML_TYPE_BF16:
+            get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q4_0:
-            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q4_1:
-            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_q<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q5_0:
-            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_q<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q5_1:
-            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_q<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q8_0:
-            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         default:
             // TODO: k-quants
-            GGML_ABORT("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
+            GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
             break;
     }
 }
 
+void get_rows_cuda(
+        const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,
+        int64_t ne00, size_t nb01, size_t nb02, size_t nb03,
+        int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,
+        size_t nb1, size_t nb2, size_t nb3,
+        cudaStream_t stream) {
+    switch (dst_type) {
+        case GGML_TYPE_F32:
+            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        case GGML_TYPE_F16:
+            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        case GGML_TYPE_BF16:
+            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (nv_bfloat16 *) dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        default:
+            GGML_ABORT("%s: unsupported dst type: %s\n", __func__, ggml_type_name(dst_type));
+            break;
+    }
+}
+
+void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ne13 == 1);
+
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0]  == ggml_type_size(dst->type));
+
+    get_rows_cuda(src0->data, src0->type, (const int32_t *) src1->data, dst->data, dst->type,
+        ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+}
+
 void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
     const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
index a1ca643f1c5300d1e2f91f111d7f3f76be4abde3..3c5bea5f48c1c71f0ceb9b8a7dad30bfad18d901 100644 (file)
@@ -3,6 +3,13 @@
 #define CUDA_GET_ROWS_BLOCK_SIZE 256
 #define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256
 
+void get_rows_cuda(
+        const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,
+        int64_t ne00, size_t nb01, size_t nb02, size_t nb03,
+        int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,
+        size_t nb1, size_t nb2, size_t nb3,
+        cudaStream_t stream);
+
 void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index fba8cb6565bae5c8d4f74e09a2471993b4cddf67..9fb2134f98d3d9b1dc3288b11ec832d196048fcd 100644 (file)
@@ -1551,7 +1551,7 @@ static void ggml_cuda_op_mul_mat(
 
             if (src1_on_device && src1_is_contiguous) {
                 quantize_src1(
-                    dev[id].src1_ddf, dev[id].src1_ddq, src0->type, ne10,
+                    dev[id].src1_ddf, nullptr, dev[id].src1_ddq, src0->type, ne10,
                     nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float),
                     src1_padded_col_size, ne11, ne12, ne13, stream);
                 CUDA_CHECK(cudaGetLastError());
@@ -1649,7 +1649,7 @@ static void ggml_cuda_op_mul_mat(
 
                 if (quantize_src1 && !src1_is_contiguous) {
                     quantize_src1(
-                        src1_ddf_i, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
+                        src1_ddf_i, nullptr, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
                         src1_padded_col_size, src1_ncols, 1, 1, stream);
                     CUDA_CHECK(cudaGetLastError());
                 }
@@ -1949,6 +1949,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
+    } else if (!split && use_mul_mat_q) {
+        ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
             !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
@@ -1964,183 +1966,145 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     }
 }
 
-struct mmid_row_mapping {
-    int32_t i1;
-    int32_t i2;
-};
-
-static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
-                                                 int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
-                                                 const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
-                                                 int64_t ne11, int64_t ne10,
-                                                 size_t nb11, size_t nb12) {
-    int32_t iid1 = blockIdx.x;
-    int32_t id = blockIdx.y;
-
-    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
-
-    if (row_id_i != i02) {
-        return;
-    }
-
-    const int64_t i11 = id % ne11;
-    const int64_t i12 = iid1;
-
-    __shared__ int src1_row;
-    if (threadIdx.x == 0) {
-        src1_row = atomicAdd(cur_src1_row, 1);
-        row_mapping[src1_row] = {id, iid1};
-    }
-    __syncthreads();
-
-    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
-    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
-
-    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
-        src1_row_contiguous[i] = src1_row_original[i];
-    }
-}
-
-static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
-                                                  const mmid_row_mapping * __restrict__ row_mapping,
-                                                  int64_t ne0,
-                                                  size_t nb1, size_t nb2) {
-    int32_t i = blockIdx.x;
-
-    const int32_t i1 = row_mapping[i].i1;
-    const int32_t i2 = row_mapping[i].i2;
-
-    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
-    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
-
-    for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
-        dst_row_original[j] = dst_row_contiguous[j];
-    }
-}
-
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * ids  = dst->src[2];
 
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ne2 == 1) {
-        if (ggml_is_quantized(src0->type)) {
-            ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
-        } else {
-            ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
-        }
-        return;
-    }
-
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
 
-    cudaStream_t stream = ctx.stream();
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    const int64_t n_as = ne02;
-    const int64_t n_ids = ids->ne[0];
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
 
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    const char * ids_dev = (const char *) ids->data;
-    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
-    CUDA_CHECK(cudaStreamSynchronize(stream));
+    if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        if (ne2 == 1) {
+            if (ggml_is_quantized(src0->type)) {
+                ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
+            } else {
+                ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
+            }
+            return;
+        }
 
-    ggml_tensor src0_row = *src0;
-    ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row  = *dst;
+        if (ggml_cuda_should_use_mmq(src0->type, cc, ne12)) {
+            ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
+            return;
+        }
+    }
 
-    char * src0_original = (char *) src0->data;
-    char * src1_original = (char *) src1->data;
-    char * dst_original  = (char *)  dst->data;
+    cudaStream_t stream = ctx.stream();
 
-    src0_row.ne[2] = 1;
-    src0_row.ne[3] = 1;
-    src0_row.nb[3] = nb02;
+    GGML_ASSERT(nb12 % nb11 == 0);
+    GGML_ASSERT(nb2  % nb1  == 0);
 
-    src1_row.ne[1] = 1;
-    src1_row.ne[2] = 1;
-    src1_row.ne[3] = 1;
-    src1_row.nb[2] = nb11;
-    src1_row.nb[3] = nb11;
+    const ggml_type type_src1_sorted = (src0->type == GGML_TYPE_F16 && !fast_fp16_hardware_available(cc))
+        || ggml_is_quantized(src0->type) ? GGML_TYPE_F32 : src0->type;
+    const ggml_type type_dst_sorted  = GGML_TYPE_F32;
+    const size_t ts_src1_sorted = ggml_type_size(type_src1_sorted);
+    const size_t ts_dst_sorted  = ggml_type_size(type_dst_sorted);
 
-    dst_row.ne[1] = 1;
-    dst_row.ne[2] = 1;
-    dst_row.ne[3] = 1;
-    dst_row.nb[2] = nb1;
-    dst_row.nb[3] = nb1;
+    const int64_t n_expert_used = ids->ne[0];
+    const int64_t ne_get_rows = ne12 * n_expert_used;
 
-    ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
-    ggml_cuda_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
+    std::vector<int32_t> ids_to_sorted_host;
+    ids_to_sorted_host.reserve(2*ne_get_rows);
+    std::vector<int32_t> ids_from_sorted_host(ne_get_rows);
 
-    src1_row.data = src1_contiguous.get();
-    dst_row.data  =  dst_contiguous.get();
+    ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool(), 2*ne_get_rows);
 
-    for (int64_t i02 = 0; i02 < n_as; i02++) {
-        int64_t num_src1_rows = 0;
+    std::vector<int32_t> tokens_per_expert(ne02);
 
-        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-            for (int64_t id = 0; id < n_ids; id++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+    ggml_cuda_pool_alloc<char> src1_sorted(ctx.pool(), ne12*n_expert_used*ne10*ts_src1_sorted);
+    ggml_cuda_pool_alloc<char>  dst_sorted(ctx.pool(), ne2 *n_expert_used* ne0*ts_dst_sorted);
 
-                GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
 
-                if (row_id_i != i02) {
-                    continue;
+    for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
+        for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
+            for (int64_t iex = 0; iex < n_expert_used; ++iex) {
+                const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
+                assert(expert_to_use >= 0 && expert_to_use < ne02);
+                if (expert_to_use == i02) {
+                    ids_from_sorted_host[i12*n_expert_used + iex] = ids_to_sorted_host.size();
+                    ids_to_sorted_host.push_back(i12*ne11 + iex % ne11);
+                    tokens_per_expert[i02]++;
+                    break;
                 }
-
-                num_src1_rows++;
             }
         }
+    }
+    GGML_ASSERT(ids_to_sorted_host.size() == size_t(ne_get_rows));
 
-        if (num_src1_rows == 0) {
-            continue;
-        }
-
-        ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-        ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-        CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
-
-        {
-            dim3 block_dims(std::min((unsigned int)ne10, 768u));
-            dim3 grid_dims(ids->ne[1], n_ids);
-            k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                    src1_original, src1_contiguous.get(),
-                    dev_cur_src1_row.get(), dev_row_mapping.get(),
-                    ids_dev, i02, ids->nb[1], ids->nb[0],
-                    ne11, ne10,
-                    nb11, nb12);
-            CUDA_CHECK(cudaGetLastError());
-        }
+    ids_to_sorted_host.insert(ids_to_sorted_host.end(), ids_from_sorted_host.begin(), ids_from_sorted_host.end());
 
-        src0_row.data = src0_original + i02*nb02;
+    CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_to_sorted_host.data(), 2*ne_get_rows*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
 
-        GGML_ASSERT(nb11 == sizeof(float)*ne10);
-        GGML_ASSERT(nb1 == sizeof(float)*ne0);
+    const int32_t * ids_to_sorted   = ids_buf_dev.ptr + 0*ne_get_rows;
+    const int32_t * ids_from_sorted = ids_buf_dev.ptr + 1*ne_get_rows;
 
-        src1_row.ne[1] = num_src1_rows;
-        src1_row.nb[1] = nb11;
-        src1_row.nb[2] = num_src1_rows*nb11;
-        src1_row.nb[3] = num_src1_rows*nb11;
+    get_rows_cuda(src1->data, src1->type, ids_to_sorted, src1_sorted.ptr, type_src1_sorted,
+        ne10, nb11, nb12, nb13,
+        ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t),
+        ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, stream);
+    CUDA_CHECK(cudaGetLastError());
 
-        dst_row.ne[1] = num_src1_rows;
-        dst_row.nb[1] = nb1;
-        dst_row.nb[2] = num_src1_rows*nb1;
-        dst_row.nb[3] = num_src1_rows*nb1;
+    char * src1_data_cur = (char *) src1_sorted.ptr;
+    char *  dst_data_cur = (char *)  dst_sorted.ptr;
+    for (int64_t i02 = 0; i02 < ne02; ++i02) {
+        if (tokens_per_expert[i02] == 0) {
+            continue;
+        }
 
-        ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+        ggml_tensor src0_slice = *src0;
+        src0_slice.ne[2] = 1;
+        src0_slice.nb[3] = src0_slice.nb[2];
+        src0_slice.data  = (char *) src0->data + i02*nb02;
+
+        ggml_tensor src1_slice;
+        memset(&src1_slice, 0, sizeof(src1_slice));
+        src1_slice.buffer = src1->buffer;
+        src1_slice.type   = type_src1_sorted;
+        src1_slice.ne[0]  = ne10;
+        src1_slice.ne[1]  = tokens_per_expert[i02];
+        src1_slice.ne[2]  = 1;
+        src1_slice.ne[3]  = 1;
+        src1_slice.nb[0]  = ts_src1_sorted;
+        src1_slice.nb[1]  = src1_slice.ne[0] * src1_slice.nb[0];
+        src1_slice.nb[2]  = src1_slice.ne[1] * src1_slice.nb[1];
+        src1_slice.nb[3]  = src1_slice.ne[2] * src1_slice.nb[2];
+        src1_slice.data   = src1_data_cur;
+
+        ggml_tensor dst_slice;
+        memset(&dst_slice, 0, sizeof(dst_slice));
+        dst_slice.buffer = dst->buffer;
+        dst_slice.type   = type_dst_sorted;
+        dst_slice.ne[0]  = ne0;
+        dst_slice.ne[1]  = tokens_per_expert[i02];
+        dst_slice.ne[2]  = 1;
+        dst_slice.ne[3]  = 1;
+        dst_slice.nb[0]  = ts_dst_sorted;
+        dst_slice.nb[1]  = dst_slice.ne[0] * dst_slice.nb[0];
+        dst_slice.nb[2]  = dst_slice.ne[1] * dst_slice.nb[1];
+        dst_slice.nb[3]  = dst_slice.ne[2] * dst_slice.nb[2];
+        dst_slice.data   = dst_data_cur;
+
+        ggml_cuda_mul_mat(ctx, &src0_slice, &src1_slice, &dst_slice);
+        CUDA_CHECK(cudaGetLastError());
 
-        {
-            dim3 block_dims(std::min((unsigned int)ne0, 768u));
-            dim3 grid_dims(num_src1_rows);
-            k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                    dst_original, dst_contiguous.get(),
-                    dev_row_mapping.get(),
-                    ne0,
-                    nb1, nb2);
-            CUDA_CHECK(cudaGetLastError());
-        }
+        src1_data_cur += src1_slice.nb[2];
+        dst_data_cur  +=  dst_slice.nb[2];
     }
+
+    get_rows_cuda(dst_sorted.ptr, type_dst_sorted, ids_from_sorted, dst->data, dst->type,
+        ne0, ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted,
+        ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t),
+        nb1, nb2, nb3, stream);
 }
 
 static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
index b36b43d5417baf099ea2a91c319837faa180cf2d..f397a7e038469648ef8e40e1fddad8bfb97a1378 100644 (file)
@@ -1,37 +1,10 @@
 #include "mmq.cuh"
+#include "quantize.cuh"
 
-void ggml_cuda_op_mul_mat_q(
-    ggml_backend_cuda_context & ctx,
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
-    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, cudaStream_t stream) {
-
-    const int64_t ne00 = src0->ne[0];
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    GGML_ASSERT(ne10 % QK8_1 == 0);
+#include <vector>
 
-    const int64_t ne0 = dst->ne[0];
-
-    const int64_t row_diff = row_high - row_low;
-    const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
-
-    int id = ggml_cuda_get_device();
-    const int cc = ggml_cuda_info().devices[id].cc;
-
-    // the main device has a larger memory buffer to hold the results from all GPUs
-    // nrows_dst == nrows of the matrix that the kernel writes into
-    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
-
-    // The stream-k decomposition is only faster for recent NVIDIA GPUs.
-    // Also its fixup needs to allocate a temporary buffer in the memory pool.
-    // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
-    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
-        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
-    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
-
-    switch (src0->type) {
+static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
+    switch (args.type_x) {
         case GGML_TYPE_Q4_0:
             mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
             break;
@@ -90,10 +63,195 @@ void ggml_cuda_op_mul_mat_q(
             GGML_ABORT("fatal error");
             break;
     }
+}
+
+void ggml_cuda_mul_mat_q(
+        ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(        dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(!ids || ids->type  == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    cudaStream_t stream = ctx.stream();
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
+
+    GGML_ASSERT(        nb00       == ts_src0);
+    GGML_ASSERT(        nb10       == ts_src1);
+    GGML_ASSERT(        nb0        == ts_dst);
+    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+
+    const char  * src0_d = (const char  *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       *  dst_d = (float       *)  dst->data;
+
+    const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s1  =  dst->nb[1] / ts_dst;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
+
+    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
+
+    if (!ids) {
+        const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
+            get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
+        ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
+
+        {
+            const int64_t s11 = src1->nb[1] / ts_src1;
+            const int64_t s12 = src1->nb[2] / ts_src1;
+            const int64_t s13 = src1->nb[3] / ts_src1;
+            quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
+                ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
+        }
+
+        const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
+        const int64_t s13 = ne12*s12;
+
+        const mmq_args args = {
+            src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
+            ne00, ne01, ne1, s01, s1,
+            ne02, ne12, s02, s12, s2,
+            ne03, ne13, s03, s13, s3,
+            use_stream_k};
+        ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
+        return;
+    }
+
+    GGML_ASSERT(ne13 == 1);
+    GGML_ASSERT(nb12 % nb11 == 0);
+    GGML_ASSERT(nb2  % nb1  == 0);
+
+    const int64_t n_expert_used = ids->ne[0];
+    const int64_t ne_get_rows = ne12 * n_expert_used;
+
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    std::vector<int32_t> ids_src1_host;
+    ids_src1_host.reserve(ne_get_rows);
+    std::vector<int32_t> ids_dst_host;
+    ids_dst_host.reserve(ne_get_rows);
+    std::vector<int32_t> tokens_per_expert_host(ne02);
+    std::vector<int32_t> expert_bounds_host(ne02 + 1);
+    ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
+
+    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
+        for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
+            for (int64_t iex = 0; iex < n_expert_used; ++iex) {
+                const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
+                assert(expert_to_use >= 0 && expert_to_use < ne02);
+                if (expert_to_use == i02) {
+                    ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
+                    ids_dst_host.push_back(i12*ne1 + iex);
+                    tokens_per_expert_host[i02]++;
+                    break;
+                }
+            }
+        }
+    }
+
+    int32_t cumsum = 0;
+    for (int64_t i = 0; i < ne02; ++i) {
+        expert_bounds_host[i] = cumsum;
+        cumsum += tokens_per_expert_host[i];
+    }
+    expert_bounds_host[ne02] = cumsum;
+
+    std::vector<int32_t> ids_buf_host;
+    ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
+    ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
+    ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
+    ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
+    ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
+    CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    const int32_t * ids_src1_dev      = ids_buf_dev.ptr;
+    const int32_t * ids_dst_dev       = ids_src1_dev + ids_src1_host.size();
+    const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size();
+
+    const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
+        get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
+    ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
+
+    const int64_t ne11_flat = ne12*n_expert_used;
+    const int64_t ne12_flat = 1;
+    const int64_t ne13_flat = 1;
+
+    {
+        const int64_t s11 = src1->nb[1] / ts_src1;
+        const int64_t s12 = src1->nb[2] / ts_src1;
+        const int64_t s13 = src1->nb[2] / ts_src1;
+        quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
+            ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+    }
+
+    const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
+    const int64_t s13 = ne12*s12;
+
+    // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
+    const mmq_args args = {
+        src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
+        ne00, ne01, ne_get_rows, s01, s1,
+        ne02, ne02, s02, s12, s2,
+        ne03, ne13, s03, s13, s3,
+        use_stream_k};
+
+    ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
+}
+
+void ggml_cuda_op_mul_mat_q(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    const int64_t ne00 = src0->ne[0];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    GGML_ASSERT(ne10 % QK8_1 == 0);
+
+    const int64_t ne0 = dst->ne[0];
+
+    const int64_t row_diff = row_high - row_low;
+    const int64_t stride01 = ne00 / ggml_blck_size(src0->type);
+
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the kernel writes into
+    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
+
+    // The stream-k decomposition is only faster for recent NVIDIA GPUs.
+    // Also its fixup needs to allocate a temporary buffer in the memory pool.
+    // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
+    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
+        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
+    const mmq_args args = {
+        src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
+        ne00, row_diff, src1_ncols, stride01, nrows_dst,
+        1, 1, 0, 0, 0,
+        1, 1, 0, 0, 0,
+        use_stream_k};
+
+    ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
 
     GGML_UNUSED(src1);
     GGML_UNUSED(dst);
     GGML_UNUSED(src1_ddf_i);
+    GGML_UNUSED(src1_padded_row_size);
 }
 
 bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
index 3cb2015520ba14c99eb399d81af71666435c17af..8c93e8326e20be99382ca6df2ea96185720b87b7 100644 (file)
@@ -13,9 +13,10 @@ using namespace ggml_cuda_mma;
 #define MMQ_ITER_K 256
 #define MMQ_NWARPS 8
 
-typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
-typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00);
-typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
+typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
+typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted,
+    float * __restrict__ dst, const int stride, const int i_max, const int j_max);
 
 enum mmq_q8_1_ds_layout {
     MMQ_Q8_1_DS_LAYOUT_D4,
@@ -233,7 +234,7 @@ static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */
 // ------------------------------------------------------------
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -289,7 +290,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -328,7 +329,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -384,7 +385,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -423,7 +424,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -495,7 +496,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -565,7 +566,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -621,7 +622,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -651,7 +652,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 
     typedef tile<16, 8, int> tile_A;
     typedef tile< 8, 8, int> tile_B;
@@ -732,7 +733,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -762,7 +763,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 
     typedef tile<16, 8, int> tile_A;
     typedef tile< 8, 8, int> tile_B;
@@ -839,7 +840,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 
     constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
     const int   * x_qs = (const int   *) x;
@@ -871,7 +872,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #ifdef NEW_MMA_AVAILABLE
 
     typedef tile<16, 4, int> tile_A;
@@ -955,7 +956,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -1011,7 +1012,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -1074,7 +1075,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #ifdef NEW_MMA_AVAILABLE
 
     typedef tile<16, 4, int> tile_A;
@@ -1201,7 +1202,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -1298,7 +1299,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -1340,7 +1341,7 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -1437,7 +1438,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -1469,7 +1470,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -1578,7 +1579,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -1610,7 +1611,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -1693,7 +1694,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -1726,7 +1727,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
 
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #ifdef NEW_MMA_AVAILABLE
 
     typedef tile<16, 4, int> tile_A;
@@ -1835,7 +1836,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -1893,7 +1894,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -1951,7 +1952,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -2007,7 +2008,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -2070,7 +2071,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -2126,7 +2127,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -2189,7 +2190,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -2245,7 +2246,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
 
 #ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
@@ -2306,8 +2307,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_dp4a(
-    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
-
+        const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
+        const int stride, const int i_max, const int j_max) {
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -2324,15 +2325,15 @@ static __device__ __forceinline__ void mmq_write_back_dp4a(
                 continue;
             }
 
-            dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+            dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
         }
     }
 }
 
 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_mma(
-    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
-
+        const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
+        const int stride, const int i_max, const int j_max) {
     typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
@@ -2362,7 +2363,7 @@ static __device__ __forceinline__ void mmq_write_back_mma(
                     continue;
                 }
 
-                dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
+                dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
             }
         }
     }
@@ -2518,17 +2519,18 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
 };
 
 template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
-static __device__ void mul_mat_q_process_tile(
-    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
-    const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
-    const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
+static __device__ __forceinline__ void mul_mat_q_process_tile(
+        const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
+        const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+        const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
+        const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
 
     constexpr int              qk         = ggml_cuda_type_traits<type>::qk;
     constexpr int              mmq_y      = get_mmq_y_device();
     constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
 
-    extern __shared__ char data_mul_mat_q[];
-    int * tile_y = (int *) data_mul_mat_q;
+    extern __shared__ int data_mul_mat_q[];
+    int * tile_y = data_mul_mat_q + mmq_x;
     int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
 
 #ifdef NEW_MMA_AVAILABLE
@@ -2543,16 +2545,11 @@ static __device__ void mul_mat_q_process_tile(
 
     float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
 
-    const int tile_x_max_i = ne01 - it*mmq_y - 1;
-    const int tile_y_max_j = ne11 - jt*mmq_x - 1;
-
-    const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
-
     for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
-        load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
+        load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
 
         {
-            const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
+            const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
 #pragma unroll
             for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
                 int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@@ -2568,7 +2565,7 @@ static __device__ void mul_mat_q_process_tile(
         __syncthreads();
 
         {
-            const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
+            const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
 #pragma unroll
             for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
                 int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@@ -2585,12 +2582,10 @@ static __device__ void mul_mat_q_process_tile(
     }
 
     if (fixup) {
-        write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
+        write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
     } else {
-        write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
+        write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j);
     }
-
-    GGML_UNUSED(ne00); GGML_UNUSED(ne10);
 }
 
 
@@ -2609,8 +2604,11 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 static __global__ void mul_mat_q(
-    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
-    const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
+        const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
+        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+        const int ncols_x, const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
+        const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
 
     // Skip unused template specializations for faster compilation:
     if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
@@ -2621,26 +2619,85 @@ static __global__ void mul_mat_q(
     constexpr int qk    = ggml_cuda_type_traits<type>::qk;
     constexpr int mmq_y = get_mmq_y_device();
 
+    const int ntx = (ncols_y + mmq_x - 1) / mmq_x; // Number of tiles x
+    const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
+
+    // Initialize the ids for writing back data with just the index.
+    // For regular matrix multiplications this is never changed.
+    // For MoE the correct indices are loaded from ids_dst.
+    extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
+        const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+
+        if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+            break;
+        }
+
+        ids_dst_shared[j] = j;
+    }
+
     // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
 #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
     {
+        const int wt = blockIdx.z / nchannels_y;
+        const int zt = blockIdx.z - wt*nchannels_y;
+        const int jt = blockIdx.y;
+        const int it = blockIdx.x;
+
+        // Defaults for regular matrix multiplication:
+        int col_low    = 0;
+        int col_high   = ncols_y;
+        int col_diff   = ncols_y;
+        int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
+        int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
+
+        if (ids_dst) {
+            col_low  = expert_bounds[zt + 0];
+            col_high = expert_bounds[zt + 1];
+            col_diff = col_high - col_low;
+
+            offset_y   = 0;
+            offset_dst = 0;
+
+            if (jt*mmq_x >= col_diff) {
+                return;
+            }
+
+#pragma unroll
+            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
+                const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+
+                if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+                    break;
+                }
+
+                ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
+            }
+        }
+
+        offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+        offset_dst += it*mmq_y;
+
+        const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
+        const int tile_y_max_j = col_diff - jt*mmq_x - 1;
+
+        const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
+
         constexpr bool fixup = false;
         mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
-            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
-                blockIdx.x, blockIdx.y, 0, ne00/qk);
+            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
+             tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
         return;
     }
 #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
 
-    const     int64_t blocks_per_ne00 = ne00 / qk;
+    const     int64_t blocks_per_ne00 = ncols_x / qk;
     constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
 
-    const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
-    const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
-
     // kbc == k block continuous, current index in continuous ijk space.
-    int64_t kbc      = (int64_t) blockIdx.x     *blocks_per_ne00*ntx*nty / gridDim.x;
-    int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;
+    int64_t kbc      = (int64_t) blockIdx.x     *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
+    int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
 
     kbc      -= (kbc      % blocks_per_ne00) % blocks_per_iter;
     kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
@@ -2649,13 +2706,64 @@ static __global__ void mul_mat_q(
     int kb0_start = kbc % blocks_per_ne00;
     int kb0_stop  = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
     while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
-        const int jt =  kbc /    (blocks_per_ne00*nty);                    // j index of current tile.
-        const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.
+        int tmp = kbc;
+        const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+        tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+        const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
+        tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
+        const int zt = tmp / (ntx*blocks_per_ne00);
+        tmp -= zt * (ntx*blocks_per_ne00);
+        const int jt = tmp / blocks_per_ne00;
+
+        // Defaults for regular matrix multiplication:
+        int col_low    = 0;
+        int col_high   = ncols_y;
+        int col_diff   = ncols_y;
+        int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
+        int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
+
+        if (ids_dst) {
+            col_low  = expert_bounds[zt + 0];
+            col_high = expert_bounds[zt + 1];
+            col_diff = col_high - col_low;
+
+            offset_y   = 0;
+            offset_dst = 0;
+
+            if (jt*mmq_x >= col_diff) {
+                kbc += blocks_per_ne00;
+                kbc -= kbc % blocks_per_ne00;
+
+                kb0_start = 0;
+                kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);
+
+                continue;
+            }
+
+#pragma unroll
+            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
+                const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+
+                if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+                    break;
+                }
+
+                ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
+            }
+        }
+
+        offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+        offset_dst += it*mmq_y;
+
+        const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
+        const int tile_y_max_j = col_diff - jt*mmq_x - 1;
+
+        const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
 
         constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
-            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
-             it, jt, kb0_start, kb0_stop);
+            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
+             tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
 
         kbc += blocks_per_ne00;
         kbc -= kbc % blocks_per_ne00;
@@ -2668,55 +2776,106 @@ static __global__ void mul_mat_q(
         return;
     }
 
-    const int jt =  kbc /    (blocks_per_ne00*nty);
-    const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+    int tmp = kbc;
+    const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+    tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+    const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
+    tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
+    const int zt = tmp / (ntx*blocks_per_ne00);
+    tmp -= zt * (ntx*blocks_per_ne00);
+    const int jt = tmp / blocks_per_ne00;
+
+    // Defaults for regular matrix multiplication:
+    int col_low    = 0;
+    int col_high   = ncols_y;
+    int col_diff   = ncols_y;
+    int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
+    int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
+
+    if (ids_dst) {
+        col_low  = expert_bounds[zt + 0];
+        col_high = expert_bounds[zt + 1];
+        col_diff = col_high - col_low;
+
+        offset_y   = 0;
+        offset_dst = 0;
+
+        if (jt*mmq_x >= col_diff) {
+            return;
+        }
+
+        // The memory layout for the fixup buffer is always contiguous, therefore reset ids:
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
+            const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+
+            if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+                break;
+            }
+
+            ids_dst_shared[j] = j;
+        }
+    }
+
+    offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+    offset_dst += it*mmq_y;
+
+    const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
+    const int tile_y_max_j = col_diff - jt*mmq_x - 1;
+
+    const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
 
     constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
-        (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
-            it, jt, kb0_start, kb0_stop);
+        (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
+         tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
 }
 
 
 template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 static __global__ void mul_mat_q_stream_k_fixup(
-    float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {
-
+        const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
+        const int ncols_x, const int nrows_x, const int ncols_y, const int stride_col_dst,
+        const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
     constexpr int     mmq_y           = get_mmq_y_device();
     constexpr int     qk              = ggml_cuda_type_traits<type>::qk;
     constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
-    const     int64_t blocks_per_ne00 = ne00 / qk;
+    const     int64_t blocks_per_ne00 = ncols_x / qk;
 
     float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
 
-    const int ntx = (ne11 + mmq_x - 1) / mmq_x;
-    const int nty = (ne01 + mmq_y - 1) / mmq_y;
-
-    bool any_fixup = false;
+    const int ntx  = (ncols_y + mmq_x - 1) / mmq_x;
+    const int nty  = (nrows_x + mmq_y - 1) / mmq_y;
 
-    const int bidx_start = ((blockIdx.y*nty + blockIdx.x)     * block_num_mmq)                           / (gridDim.y*gridDim.x);
-    const int bidx_stop  = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x);
+    const int bidx0 = blockIdx.x;
 
-    int64_t kbc_0;
-    int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq;
-
-    for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
-        kbc_0 = kbc_stop_0;
-        kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
+    // kbc == k block continuous, current index in continuous ijk space.
+    int64_t kbc0      = (int64_t) bidx0     *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
+    int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
 
-        const int64_t kbc      = kbc_0      - (kbc_0      % blocks_per_ne00) % blocks_per_iter;
-        const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter;
+    kbc0      -= (kbc0      % blocks_per_ne00) % blocks_per_iter;
+    kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;
 
-        // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
-        if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
-            continue;
-        }
+    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
+    const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;
+    const bool did_not_write_last      = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;
+    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+        return;
+    }
 
-        const int jt =  kbc_stop /    (blocks_per_ne00*nty);
-        const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+    bool any_fixup = false;
 
-        // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
-        if ((unsigned)it != blockIdx.x || (unsigned)jt != blockIdx.y) {
+    // Iterate over previous blocks and sum up partial sums written to fixup buffer.
+    // All CUDA blocks that get here must have a previous block that needs a fixup.
+    int64_t bidx = bidx0 - 1;
+    int64_t kbc_stop = kbc0;
+    while(true) {
+        int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
+        kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
+
+        if (kbc == kbc_stop) { // Did not have any data.
+            bidx--;
+            kbc_stop = kbc;
             continue;
         }
 
@@ -2733,16 +2892,71 @@ static __global__ void mul_mat_q_stream_k_fixup(
                 sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
             }
         }
+
+        // If this block started in a previous tile we are done and don't need to combine additional partial results.
+        if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) {
+            break;
+        }
+        bidx--;
+        kbc_stop = kbc;
     }
 
     if (!any_fixup) {
         return;
     }
 
-    dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;
+    int tmp = kbc0;
+    const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+    tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+    const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
+    tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
+    const int zt = tmp / (ntx*blocks_per_ne00);
+    tmp -= zt * (ntx*blocks_per_ne00);
+    const int jt = tmp / blocks_per_ne00;
 
-    const int i_max = ne01 - blockIdx.x*mmq_y - 1;
-    const int j_max = ne11 - blockIdx.y*mmq_x - 1;
+    if (!ids_dst) {
+        const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
+        dst += offset_dst;
+
+        const int i_max = nrows_x - it*mmq_y - 1;
+        const int j_max = ncols_y - jt*mmq_x - 1;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (j > j_max) {
+                return;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                if (need_check && i > i_max) {
+                    continue;
+                }
+
+                dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+            }
+        }
+        return;
+    }
+
+    __shared__ int ids_dst_shared[mmq_x];
+    const int col_low  = expert_bounds[zt + 0];
+    const int col_high = expert_bounds[zt + 1];
+    const int col_diff = col_high - col_low;
+
+    for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) {
+        ids_dst_shared[j] = ids_dst[col_low + j];
+    }
+
+    const int offset_dst = it*mmq_y;
+    dst += offset_dst;
+
+    const int i_max = nrows_x  - it*mmq_y - 1;
+    const int j_max = col_diff - jt*mmq_x - 1;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -2760,26 +2974,27 @@ static __global__ void mul_mat_q_stream_k_fixup(
                 continue;
             }
 
-            dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+            dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
         }
     }
 }
 
 struct mmq_args {
-    const char * x; const char * y; float * dst;
-    int64_t ne00; int64_t ne01; int64_t stride01;
-    int64_t ne10; int64_t ne11; int64_t stride11;
-    int64_t ne0;
+    const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
+    int64_t ncols_x; int64_t nrows_x; int64_t ncols_y; int64_t stride_row_x; int64_t nrows_dst;
+    int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
+    int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
     bool use_stream_k;
 };
 
 template<ggml_type type>
-static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
+static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) {
     const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
     const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
-    const int shmem_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
-    const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
-    return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
+    const size_t nbs_ids = mmq_x*sizeof(int);
+    const size_t nbs_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+    const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
+    return nbs_ids + nbs_x + GGML_PAD(nbs_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
 }
 
 template <ggml_type type, int mmq_x>
@@ -2791,86 +3006,114 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
     const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
 
-    const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
+    const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
 
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-    static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
-    if (!shmem_limit_raised[id]) {
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
-        shmem_limit_raised[id] = true;
+    static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+    if (!shared_memory_limit_raised[id]) {
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
+        shared_memory_limit_raised[id] = true;
     }
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
 
-    const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
-    const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
-    const dim3 block_nums_xy_tiling(nty, ntx, 1);
+    const int nty  = (args.nrows_x + mmq_y - 1) / mmq_y;
+    const int ntx  = (args.ncols_y + mmq_x - 1) / mmq_x;
+    const int ntzw = args.nchannels_y * args.nsamples_y;
+    const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
+
+    GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0);
+    GGML_ASSERT(args.nsamples_y  % args.nsamples_x  == 0);
+    const int channel_ratio = args.nchannels_y / args.nchannels_x;
+    const int sample_ratio  = args.nsamples_y  / args.nsamples_x;
 
     if (!args.use_stream_k) {
-        if (args.ne01 % mmq_y == 0) {
+        if (args.nrows_x % mmq_y == 0) {
             constexpr bool need_check = false;
-            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
-                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
+                 args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+                 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
         } else {
             constexpr bool need_check = true;
-            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
-                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
+                 args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+                 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
         }
         return;
     }
 
-    const dim3 block_nums_mmq(nsm, 1, 1);
+    const dim3 block_nums_stream_k(nsm, 1, 1);
+    const bool fixup_needed = ntx*nty*ntzw % nsm != 0;
 
     ggml_cuda_pool & pool = ctx.pool(id);
-    ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);
+    ggml_cuda_pool_alloc<float> tmp_fixup(pool);
+    if (fixup_needed) {
+        tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
+    }
 
-    if (args.ne01 % mmq_y == 0) {
+    if (args.nrows_x % mmq_y == 0) {
         constexpr bool need_check = false;
 
-        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
-            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
+             args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+
+        if (!fixup_needed) {
+            return;
+        }
 
-        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
-            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
+             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
     } else {
         constexpr bool need_check = true;
 
-        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
-            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
+             args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+
+        if (!fixup_needed) {
+            return;
+        }
 
-        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
-            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
+             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
     }
 }
 
 template <ggml_type type>
 void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
-    const int id    = ggml_cuda_get_device();
-    const int cc    = ggml_cuda_info().devices[id].cc;
-    const int smpbo = ggml_cuda_info().devices[id].smpbo;
+    const int    id    = ggml_cuda_get_device();
+    const int    cc    = ggml_cuda_info().devices[id].cc;
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
 
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);
-    const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
-    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
 
     int mmq_x_best  = 0;
-    int nparts_best = INT_MAX;
+    int ntiles_x_best = INT_MAX;
 
-    for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
+    for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
         const int granularity = mmq_get_granularity_host(mmq_x, cc);
 
-        if (mmq_x % granularity != 0 || mmq_get_shmem<type>(mmq_x, mmq_y, cc) > smpbo) {
+        if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc) > smpbo) {
             continue;
         }
 
-        const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
-        const int nwaves_xy_tiling = ntiles_x*block_num_y;
-        const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
+        const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;
 
-        if (nparts < nparts_best) {
-            mmq_x_best  = mmq_x;
-            nparts_best = nparts;
+        if (ntiles_x < ntiles_x_best) {
+            mmq_x_best = mmq_x;
+            ntiles_x_best = ntiles_x;
         }
     }
 
@@ -2954,6 +3197,9 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
 
 // -------------------------------------------------------------------------------------------------------------------------
 
+void ggml_cuda_mul_mat_q(
+        ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
 void ggml_cuda_op_mul_mat_q(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
index d846e35a6a26d87848e314de7877ba35c30f05e6..132c466fd1aa6bc2d22311b2b7ee611aa9f8aebd 100644 (file)
@@ -158,7 +158,7 @@ static __global__ void mul_mat_vec_q(
     const     int blocks_per_row_x = ncols_x / qk;
     constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 
-    // The MUL_MAT_ID code path with ids != nullptr is only implemetned for ncols_dst == 1.
+    // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
     const int channel_dst = blockIdx.y;
     const int channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]          : channel_dst / channel_ratio;
     const int channel_y   = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
@@ -507,7 +507,7 @@ void ggml_cuda_mul_mat_vec_q(
     GGML_ASSERT(        nb0        == ts_dst);
     GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
 
-    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for  batch size 1.
+    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
 
     const float   * src1_d =       (const float   *) src1->data;
     const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
@@ -519,7 +519,7 @@ void ggml_cuda_mul_mat_vec_q(
         const int64_t s11 = src1->nb[1] / ts_src1;
         const int64_t s12 = src1->nb[2] / ts_src1;
         const int64_t s13 = src1->nb[3] / ts_src1;
-        quantize_row_q8_1_cuda(src1_d, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
+        quantize_row_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
     }
 
     const int64_t s01 = src0->nb[1] / ts_src0;
index 3bab47d56a22ee98b4f52b80979d2f155115d962..931a45ad347dcca02717191a2cd37d0e799c73c1 100644 (file)
@@ -49,29 +49,38 @@ static __global__ void quantize_q8_1(
 
 template <mmq_q8_1_ds_layout ds_layout>
 static __global__ void quantize_mmq_q8_1(
-    const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
+        const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
+        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t ne0, const int ne1, const int ne2) {
 
     constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
     constexpr int vals_per_sum   = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
 
-    const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
+    const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
 
-    if (ix0 >= kx0_padded) {
+    if (i0 >= ne0) {
         return;
     }
 
-    const float4 * x4 = (const float4 *) x;
+    const int64_t i1 = blockIdx.y;
+    const int64_t i2 = blockIdx.z % ne2;
+    const int64_t i3 = blockIdx.z / ne2;
 
-    const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
+    const int64_t i00 = i0;
+    const int64_t i01 = ids ? ids[i1] : i1;
+    const int64_t i02 = i2;
+    const int64_t i03 = i3;
+
+    const float4 * x4 = (const float4 *) x;
 
     block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
 
     const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
-    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;                   // block index in channel
-    const int64_t iqs = ix0 % (4*QK8_1);                                            // quant index in block
+    const int64_t ib  = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y;                    // block index in channel
+    const int64_t iqs = i0 % (4*QK8_1);                                             // quant index in block
 
     // Load 4 floats per thread and calculate max. abs. value between them:
-    const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+    const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
     float amax = fabsf(xi.x);
     amax = fmaxf(amax, fabsf(xi.y));
     amax = fmaxf(amax, fabsf(xi.z));
@@ -87,7 +96,7 @@ static __global__ void quantize_mmq_q8_1(
     if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
         sum = xi.x + xi.y + xi.z + xi.w;
 
-        // Exchange calculate sum across vals_per_sum/4 threads.
+        // Calculate sums across vals_per_sum/4 threads.
 #pragma unroll
         for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
             sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
@@ -137,9 +146,10 @@ static __global__ void quantize_mmq_q8_1(
 }
 
 void quantize_row_q8_1_cuda(
-    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
-
+        const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
+        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
+    GGML_ASSERT(!ids);
     GGML_ASSERT(ne0 % QK8_1 == 0);
 
     const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
@@ -150,9 +160,9 @@ void quantize_row_q8_1_cuda(
 }
 
 void quantize_mmq_q8_1_cuda(
-    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
-
+        const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
+        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
     GGML_ASSERT(ne0 % (4*QK8_1) == 0);
 
     const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
@@ -161,21 +171,18 @@ void quantize_mmq_q8_1_cuda(
     switch (mmq_get_q8_1_ds_layout(type_src0)) {
         case MMQ_Q8_1_DS_LAYOUT_D4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
+                <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
             break;
         case MMQ_Q8_1_DS_LAYOUT_DS4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
+                <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
             break;
         case MMQ_Q8_1_DS_LAYOUT_D2S6:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
+                <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
             break;
         default:
             GGML_ABORT("fatal error");
             break;
     }
-    GGML_UNUSED(s01);
-    GGML_UNUSED(s02);
-    GGML_UNUSED(s03);
 }
index b627c4e4008b4a3bc010c8251b7ea4fe8702c57c..725ab52443c0e9e8720694b7e06d3c6db1c16fd9 100644 (file)
@@ -12,13 +12,16 @@ static_assert(MATRIX_ROW_PADDING %    CUDA_QUANTIZE_BLOCK_SIZE      == 0, "Risk
 static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
 
 typedef void (*quantize_cuda_t)(
-    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
+        const float * x, const int32_t * ids, void * vy,
+        ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
 
 void quantize_row_q8_1_cuda(
-    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
+        const float * x, const int32_t * ids, void * vy,
+        ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
 
 void quantize_mmq_q8_1_cuda(
-    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
+        const float * x, const int32_t * ids, void * vy,
+        ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
index d70acb77194352ac366634917d94471ffd3999a9..9591b1a89e72396c810be0c7b875dfed5e4d5f2f 100644 (file)
@@ -4184,6 +4184,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+            // test cases with large ne00/ne10 to cover stream-k fixup
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 1024, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 1024, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1}));
         }
     }
     for (ggml_type type_a : other_types) {