]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: noncont MMVQ + batched bs1 MUL_MAT_ID (#13014)
authorJohannes Gäßler <redacted>
Tue, 22 Apr 2025 19:27:40 +0000 (21:27 +0200)
committerGitHub <redacted>
Tue, 22 Apr 2025 19:27:40 +0000 (21:27 +0200)
* CUDA: noncont MMVQ + batched bs1 MUL_MAT_ID

* fix logic for RoPE support, CUDA graphs

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mmv.cu
ggml/src/ggml-cuda/mmv.cuh
ggml/src/ggml-cuda/mmvq.cu
ggml/src/ggml-cuda/mmvq.cuh
ggml/src/ggml-cuda/quantize.cu
ggml/src/ggml-cuda/quantize.cuh
ggml/src/ggml-cuda/vecdotq.cuh
tests/test-backend-ops.cpp

index a7febef723c2ecb01ca9aeb027ee3f5d6c955d8f..e0e0d2137f3be28ff70304b5caf14c1e257142de 100644 (file)
@@ -1410,6 +1410,11 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
+    // const int64_t nb10 = src1->nb[0];
+    const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2];
+    const int64_t nb13 = src1->nb[3];
+
     const int64_t nb2 = dst->nb[2];
     const int64_t nb3 = dst->nb[3];
 
@@ -1545,7 +1550,10 @@ static void ggml_cuda_op_mul_mat(
             dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
 
             if (src1_on_device && src1_is_contiguous) {
-                quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
+                quantize_src1(
+                    dev[id].src1_ddf, dev[id].src1_ddq, src0->type, ne10,
+                    nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float),
+                    src1_padded_col_size, ne11, ne12, ne13, stream);
                 CUDA_CHECK(cudaGetLastError());
             }
         }
@@ -1640,7 +1648,9 @@ static void ggml_cuda_op_mul_mat(
                 }
 
                 if (quantize_src1 && !src1_is_contiguous) {
-                    quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
+                    quantize_src1(
+                        src1_ddf_i, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
+                        src1_padded_col_size, src1_ncols, 1, 1, stream);
                     CUDA_CHECK(cudaGetLastError());
                 }
 
@@ -1878,7 +1888,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
 
-    bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+    bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
@@ -1919,10 +1929,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+    if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
-        ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
+        ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
+    } else if (!split && use_mul_mat_vec_q) {
+        ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
                && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
@@ -1999,6 +2011,15 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
+    if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ne2 == 1) {
+        if (ggml_is_quantized(src0->type)) {
+            ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
+        } else {
+            ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
+        }
+        return;
+    }
+
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
 
     cudaStream_t stream = ctx.stream();
@@ -2035,97 +2056,75 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     dst_row.nb[2] = nb1;
     dst_row.nb[3] = nb1;
 
-    if (ne12 == 1) {
-        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-            for (int64_t id = 0; id < n_ids; id++) {
-                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-
-                GGML_ASSERT(i02 >= 0 && i02 < n_as);
-
-                const int64_t i11 = id % ne11;
-                const int64_t i12 = iid1;
-
-                const int64_t i1 = id;
-                const int64_t i2 = i12;
-
-                src0_row.data = src0_original + i02*nb02;
-                src1_row.data = src1_original + i11*nb11 + i12*nb12;
-                dst_row.data  =  dst_original + i1*nb1   + i2*nb2;
-
-                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
-            }
-        }
-    } else {
-        ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
-        ggml_cuda_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
-
-        src1_row.data = src1_contiguous.get();
-        dst_row.data  =  dst_contiguous.get();
+    ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
+    ggml_cuda_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
 
-        for (int64_t i02 = 0; i02 < n_as; i02++) {
-            int64_t num_src1_rows = 0;
+    src1_row.data = src1_contiguous.get();
+    dst_row.data  =  dst_contiguous.get();
 
-            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-                for (int64_t id = 0; id < n_ids; id++) {
-                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+    for (int64_t i02 = 0; i02 < n_as; i02++) {
+        int64_t num_src1_rows = 0;
 
-                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-                    if (row_id_i != i02) {
-                        continue;
-                    }
+                GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
 
-                    num_src1_rows++;
+                if (row_id_i != i02) {
+                    continue;
                 }
-            }
 
-            if (num_src1_rows == 0) {
-                continue;
+                num_src1_rows++;
             }
+        }
 
-            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
+        if (num_src1_rows == 0) {
+            continue;
+        }
 
-            {
-                dim3 block_dims(std::min((unsigned int)ne10, 768u));
-                dim3 grid_dims(ids->ne[1], n_ids);
-                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                        src1_original, src1_contiguous.get(),
-                        dev_cur_src1_row.get(), dev_row_mapping.get(),
-                        ids_dev, i02, ids->nb[1], ids->nb[0],
-                        ne11, ne10,
-                        nb11, nb12);
-                CUDA_CHECK(cudaGetLastError());
-            }
+        ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+        ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+        CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
+
+        {
+            dim3 block_dims(std::min((unsigned int)ne10, 768u));
+            dim3 grid_dims(ids->ne[1], n_ids);
+            k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                    src1_original, src1_contiguous.get(),
+                    dev_cur_src1_row.get(), dev_row_mapping.get(),
+                    ids_dev, i02, ids->nb[1], ids->nb[0],
+                    ne11, ne10,
+                    nb11, nb12);
+            CUDA_CHECK(cudaGetLastError());
+        }
 
-            src0_row.data = src0_original + i02*nb02;
+        src0_row.data = src0_original + i02*nb02;
 
-            GGML_ASSERT(nb11 == sizeof(float)*ne10);
-            GGML_ASSERT(nb1 == sizeof(float)*ne0);
+        GGML_ASSERT(nb11 == sizeof(float)*ne10);
+        GGML_ASSERT(nb1 == sizeof(float)*ne0);
 
-            src1_row.ne[1] = num_src1_rows;
-            src1_row.nb[1] = nb11;
-            src1_row.nb[2] = num_src1_rows*nb11;
-            src1_row.nb[3] = num_src1_rows*nb11;
+        src1_row.ne[1] = num_src1_rows;
+        src1_row.nb[1] = nb11;
+        src1_row.nb[2] = num_src1_rows*nb11;
+        src1_row.nb[3] = num_src1_rows*nb11;
 
-            dst_row.ne[1] = num_src1_rows;
-            dst_row.nb[1] = nb1;
-            dst_row.nb[2] = num_src1_rows*nb1;
-            dst_row.nb[3] = num_src1_rows*nb1;
+        dst_row.ne[1] = num_src1_rows;
+        dst_row.nb[1] = nb1;
+        dst_row.nb[2] = num_src1_rows*nb1;
+        dst_row.nb[3] = num_src1_rows*nb1;
 
-            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+        ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
 
-            {
-                dim3 block_dims(std::min((unsigned int)ne0, 768u));
-                dim3 grid_dims(num_src1_rows);
-                k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                        dst_original, dst_contiguous.get(),
-                        dev_row_mapping.get(),
-                        ne0,
-                        nb1, nb2);
-                CUDA_CHECK(cudaGetLastError());
-            }
+        {
+            dim3 block_dims(std::min((unsigned int)ne0, 768u));
+            dim3 grid_dims(num_src1_rows);
+            k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                    dst_original, dst_contiguous.get(),
+                    dev_row_mapping.get(),
+                    ne0,
+                    nb1, nb2);
+            CUDA_CHECK(cudaGetLastError());
         }
     }
 }
@@ -2489,7 +2488,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #endif
         }
 
-        if (node->op == GGML_OP_MUL_MAT_ID) {
+        if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
             use_cuda_graph = false; // This node type is not supported by CUDA graph capture
 #ifndef NDEBUG
             GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
@@ -3203,9 +3202,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         }
         case GGML_OP_ROPE:
         case GGML_OP_ROPE_BACK: {
-            const size_t ts = ggml_type_size(op->src[0]->type);
-            const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
-            return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
+            return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
         }
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
index b39961cd1154da527043eff044db47001cc3a8a1..d8c385e2399aea4bbd6354d89d06f680eda6f58f 100644 (file)
@@ -4,18 +4,23 @@
 
 template <typename T, typename type_acc, int block_size>
 static __global__ void mul_mat_vec(
-        const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
+        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
-    const int64_t row       = blockIdx.x;
-    const int64_t channel   = blockIdx.y;
-    const int64_t sample    = blockIdx.z;
-    const int     tid       = threadIdx.x;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-
-    x   +=  (sample/sample_ratio)*stride_sample_x   + (channel/channel_ratio)*stride_channel_x + row*stride_row;
-    y   +=   sample              *stride_sample_y   +  channel               *stride_channel_y;
-    dst +=   sample              *stride_sample_dst +  channel               *stride_channel_dst;
+    const int64_t row         = blockIdx.x;
+    const int64_t channel_dst = blockIdx.y;
+    const int64_t channel_x   = ids ? ids[channel_dst]          : channel_dst / channel_ratio;
+    const int64_t channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
+    const int64_t sample_dst  = blockIdx.z;
+    const int64_t sample_x    = sample_dst / sample_ratio;
+    const int64_t sample_y    = sample_dst;
+    const int     tid         = threadIdx.x;
+    constexpr int warp_size   = ggml_cuda_get_physical_warp_size();
+
+    x   += sample_x  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
+    y   += sample_y  *stride_sample_y   + channel_y  *stride_channel_y;
+    dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
 
     const float2 * y2 = (const float2 *) y;
 
@@ -31,12 +36,19 @@ static __global__ void mul_mat_vec(
 
     float sumf = 0.0f;
 
-    if constexpr (std::is_same<T, half>::value) {
+    if constexpr (std::is_same<T, float>::value) {
+        const float2 * x2 = (const float2 *) x;
+
+        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            const float2 tmpx = x2[col2];
+            const float2 tmpy = y2[col2];
+            sumf += tmpx.x*tmpy.x;
+            sumf += tmpx.y*tmpy.y;
+        }
+    } else if constexpr (std::is_same<T, half>::value) {
         const half2 * x2 = (const half2 *) x;
 
         if (std::is_same<type_acc, float>::value) {
-            sumf = 0.0f;
-
             for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
                 const float2 tmpx = __half22float2(x2[col2]);
                 const float2 tmpy = y2[col2];
@@ -59,8 +71,6 @@ static __global__ void mul_mat_vec(
         }
     } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
         const int * x2 = (const int *) x;
-        sumf = 0.0f;
-
         for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
             const int    tmpx = x2[col2];
             const float2 tmpy = y2[col2];
@@ -92,17 +102,17 @@ static __global__ void mul_mat_vec(
 
 template <typename T, typename type_acc>
 static void launch_mul_mat_vec_cuda(
-        const T * x, const float * y, float * dst,
-        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
-        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream) {
     GGML_ASSERT(ncols      % 2 == 0);
     GGML_ASSERT(stride_row % 2 == 0);
-    GGML_ASSERT(nchannels_y % nchannels_x == 0);
-    GGML_ASSERT(nsamples_y  % nsamples_x  == 0);
-    const int64_t channel_ratio = nchannels_y / nchannels_x;
-    const int64_t sample_ratio  = nsamples_y  / nsamples_x;
+    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
+    const int64_t channel_ratio = nchannels_dst / nchannels_x;
+    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
     int device;
     int warp_size;
 
@@ -124,48 +134,48 @@ static void launch_mul_mat_vec_cuda(
     }
 
     const int smem = warp_size*sizeof(float);
-    const dim3 block_nums(nrows, nchannels_y, nsamples_y);
+    const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case   32: {
             mul_mat_vec<T, type_acc,  32><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   64: {
             mul_mat_vec<T, type_acc,  64><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   96: {
             mul_mat_vec<T, type_acc,  96><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  128: {
             mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  160: {
             mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  192: {
             mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  224: {
             mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  256: {
             mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         default: {
             GGML_ABORT("fatal error");
@@ -175,28 +185,28 @@ static void launch_mul_mat_vec_cuda(
 
 template<typename T>
 static void mul_mat_vec_cuda(
-        const T * x, const float * y, float * dst,
-        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
-        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         enum ggml_prec prec, cudaStream_t stream) {
-    switch (prec) {
-        case GGML_PREC_DEFAULT: {
+    if constexpr(std::is_same<T, half>::value) {
+        if (prec == GGML_PREC_DEFAULT) {
             launch_mul_mat_vec_cuda<T, half>
-                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-        } break;
-        case GGML_PREC_F32: {
-            launch_mul_mat_vec_cuda<T, float>
-                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-        } break;
+                (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            return;
+        }
     }
+    launch_mul_mat_vec_cuda<T, float>
+        (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+         stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
 }
 
-void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(         dst->type == GGML_TYPE_F32);
 
     GGML_TENSOR_BINARY_OP_LOCALS;
 
@@ -204,21 +214,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const size_t ts_src1 = ggml_type_size(src1->type);
     const size_t ts_dst  = ggml_type_size(dst->type);
 
-    GGML_ASSERT(ne11 == 1);
-    GGML_ASSERT(ne12 == ne2);
+    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for  batch size 1.
     GGML_ASSERT(ne13 == ne3);
 
-    GGML_ASSERT(nb00 == ts_src0);
-    GGML_ASSERT(nb10 == ts_src1);
-    GGML_ASSERT(nb0  == ts_dst);
+    GGML_ASSERT(        nb00       == ts_src0);
+    GGML_ASSERT(        nb10       == ts_src1);
+    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+    GGML_ASSERT(        nb0        == ts_dst);
 
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
 
-    const float * src1_d = (const float *) src1->data;
-    float       *  dst_d = (float       *)  dst->data;
+    const float   * src1_d =       (const float   *) src1->data;
+    const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
+    float         *  dst_d =       (float         *)  dst->data;
 
     const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s11 = src1->nb[1] / ts_src1;
+    const int64_t s1  =  dst->nb[1] / ts_dst;
     const int64_t s02 = src0->nb[2] / ts_src0;
     const int64_t s12 = src1->nb[2] / ts_src1;
     const int64_t s2  =  dst->nb[2] / ts_dst;
@@ -226,14 +239,33 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const int64_t s13 = src1->nb[3] / ts_src1;
     const int64_t s3  =  dst->nb[3] / ts_dst;
 
+    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
+    const int64_t ncols_dst          = ids ? ne2  : ne1;
+    const int64_t nchannels_y        = ids ? ne11 : ne12;
+    const int64_t nchannels_dst      = ids ? ne1  : ne2;
+    const int64_t stride_channel_dst = ids ? s1   : s2;
+    const int64_t stride_channel_y   = ids ? s11  : s12;
+
+    GGML_ASSERT(ncols_dst == 1);
+
     switch (src0->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0->data;
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+        } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@@ -262,27 +294,34 @@ void ggml_cuda_op_mul_mat_vec(
     const int64_t stride_row         = ne00;
     const int64_t nchannels_x        = 1;
     const int64_t nchannels_y        = 1;
+    const int64_t nchannels_dst      = 1;
     const int64_t stride_channel_x   = 0;
     const int64_t stride_channel_y   = 0;
     const int64_t stride_channel_dst = 0;
     const int64_t nsamples_x         = 1;
-    const int64_t nsamples_y         = 1;
+    const int64_t nsamples_dst       = 1;
     const int64_t stride_sample_x    = 0;
     const int64_t stride_sample_y    = 0;
     const int64_t stride_sample_dst  = 0;
 
     switch (src0->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0_dd_i;
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+        } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
index 78a1cd4a6906af62f59cd58b7586852f3b6988fc..756e7e1cc7fc36cd5874aeb5719157f5dd50c851 100644 (file)
@@ -3,7 +3,7 @@
 // maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
 #define MMV_MAX_ROWS 512
 
-void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
 
 void ggml_cuda_op_mul_mat_vec(
     ggml_backend_cuda_context & ctx,
index eef8585a7380abfc838ddcb39da0587a12880487..cac04916cd8f0b5970f54f95b5e972b919a92ebe 100644 (file)
@@ -1,6 +1,9 @@
 #include "mmvq.cuh"
+#include "quantize.cuh"
 #include "vecdotq.cuh"
 
+#include <cstdint>
+
 typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
 
 static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
@@ -73,9 +76,9 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
     return MMVQ_PARAMETERS_GENERIC;
 }
 
-static constexpr __host__ __device__ int calc_nwarps(int ncols_y,  mmvq_parameter_table_id table_id) {
+static constexpr __host__ __device__ int calc_nwarps(int ncols_dst,  mmvq_parameter_table_id table_id) {
     if (table_id == MMVQ_PARAMETERS_GENERIC) {
-        switch (ncols_y) {
+        switch (ncols_dst) {
             case 1:
             case 2:
             case 3:
@@ -90,7 +93,7 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_y,  mmvq_paramete
                 return 1;
         }
     } else if (table_id == MMVQ_PARAMETERS_GCN) {
-        switch (ncols_y) {
+        switch (ncols_dst) {
             case 1:
             case 2:
             case 3:
@@ -107,9 +110,9 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_y,  mmvq_paramete
     return 1;
 }
 
-static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
+static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
     if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
-        switch (ncols_y) {
+        switch (ncols_dst) {
             case 1:
                 return 1;
             case 2:
@@ -127,19 +130,21 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int ta
     return 1;
 }
 
-template <ggml_type type, int ncols_y>
+template <ggml_type type, int ncols_dst>
 // tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
+__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+        const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
 
     constexpr int qk  = ggml_cuda_type_traits<type>::qk;
     constexpr int qi  = ggml_cuda_type_traits<type>::qi;
     constexpr int vdr = get_vdr_mmvq(type);
     constexpr mmvq_parameter_table_id table_id = get_device_table_id();
-    constexpr int nwarps = calc_nwarps(ncols_y, table_id);
-    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
+    constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
+    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
@@ -147,13 +152,21 @@ static __global__ void mul_mat_vec_q(
     const     int tid = warp_size*threadIdx.y + threadIdx.x;
     const     int row0 = rows_per_cuda_block*blockIdx.x;
     const     int blocks_per_row_x = ncols_x / qk;
-    const     int blocks_per_col_y = nrows_y / QK8_1;
     constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 
+    // The MUL_MAT_ID code path with ids != nullptr is only implemetned for ncols_dst == 1.
+    const int channel_dst = blockIdx.y;
+    const int channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]          : channel_dst / channel_ratio;
+    const int channel_y   = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
+    const int sample_dst  = blockIdx.z;
+    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_y    = sample_dst;
+
     // partial sum for each thread
-    float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}};
+    float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
 
-    const block_q8_1 * y = (const block_q8_1 *) vy;
+    const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
+    const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
 
     for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
         const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
@@ -162,18 +175,19 @@ static __global__ void mul_mat_vec_q(
         const int kqs = vdr * (tid % (qi/vdr));
 
 #pragma unroll
-        for (int j = 0; j < ncols_y; ++j) {
+        for (int j = 0; j < ncols_dst; ++j) {
 #pragma unroll
             for (int i = 0; i < rows_per_cuda_block; ++i) {
-                tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
+                tmp[j][i] += vec_dot_q_cuda(
+                    vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
             }
         }
     }
 
-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
     if (threadIdx.y > 0) {
 #pragma unroll
-        for (int j = 0; j < ncols_y; ++j) {
+        for (int j = 0; j < ncols_dst; ++j) {
 #pragma unroll
             for (int i = 0; i < rows_per_cuda_block; ++i) {
                 tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
@@ -185,9 +199,11 @@ static __global__ void mul_mat_vec_q(
         return;
     }
 
+    dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
+
     // sum up partial sums and write back result
 #pragma unroll
-    for (int j = 0; j < ncols_y; ++j) {
+    for (int j = 0; j < ncols_dst; ++j) {
 #pragma unroll
         for (int i = 0; i < rows_per_cuda_block; ++i) {
 #pragma unroll
@@ -197,88 +213,121 @@ static __global__ void mul_mat_vec_q(
             tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
         }
 
-        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) {
-            dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
+        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + int(threadIdx.x) < stride_col_dst)) {
+            dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x];
         }
     }
-
-    GGML_UNUSED(nrows_x);
 }
 
-static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
-    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
-    const dim3 block_nums(nblocks, 1, 1);
-    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
+static std::pair<dim3, dim3> calc_launch_params(
+        const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
+        const int warp_size, const mmvq_parameter_table_id table_id) {
+    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
+    const dim3 block_nums(nblocks, nchannels_y, nsamples_y);
+    const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
     return {block_nums, block_dims};
 }
 
 template <ggml_type type>
-static void mul_mat_vec_q_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+static void mul_mat_vec_q_switch_ncols_dst(
+        const void * vx, const void * vy, const int32_t * ids, float * dst,
+        const int ncols_x, const int nrows_x, const int ncols_dst,
+        const int stride_row_x, const int stride_col_y, const int stride_col_dst,
+        const int nchannels_x, const int nchannels_y, const int nchannels_dst,
+        const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        cudaStream_t stream) {
 
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
-    GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
+    GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
+
+    const int channel_ratio = nchannels_dst / nchannels_x;
+    const int sample_ratio  = nsamples_dst  / nsamples_x;
 
     const int device = ggml_cuda_get_device();
     const int warp_size = ggml_cuda_info().devices[device].warp_size;
     const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
 
-    switch (ncols_y) {
+    GGML_ASSERT(!ids || ncols_dst == 1);
+    switch (ncols_dst) {
         case 1:
         {
-            constexpr int c_ncols_y = 1;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 1;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         case 2:
         {
-            constexpr int c_ncols_y = 2;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 2;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         case 3:
         {
-            constexpr int c_ncols_y = 3;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 3;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         case 4:
         {
-            constexpr int c_ncols_y = 4;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 4;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         case 5:
         {
-            constexpr int c_ncols_y = 5;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 5;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         case 6:
         {
-            constexpr int c_ncols_y = 6;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 6;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         case 7:
         {
-            constexpr int c_ncols_y = 7;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 7;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         case 8:
         {
-            constexpr int c_ncols_y = 8;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 8;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         default:
@@ -287,221 +336,241 @@ static void mul_mat_vec_q_cuda(
     }
 }
 
-static void mul_mat_vec_q4_0_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q4_1_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q5_0_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q5_1_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q8_0_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q2_K_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q3_K_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q4_K_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q5_K_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q6_K_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq2_xxs_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq2_xs_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq2_s_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq3_xxs_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq1_s_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq1_m_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq4_nl_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq4_xs_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq3_s_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-void ggml_cuda_op_mul_mat_vec_q(
-    ggml_backend_cuda_context & ctx,
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
-    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, cudaStream_t stream) {
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t row_diff = row_high - row_low;
-
-    const int64_t ne10 = src1->ne[0];
-    GGML_ASSERT(ne10 % QK8_1 == 0);
-
-    const int64_t ne0 = dst->ne[0];
-
-    int id = ggml_cuda_get_device();
-
-    // the main device has a larger memory buffer to hold the results from all GPUs
-    // nrows_dst == nrows of the matrix that the kernel writes into
-    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
-
-    switch (src0->type) {
+static void mul_mat_vec_q_switch_type(
+        const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst,
+        const int ncols_x, const int nrows_x, const int ncols_dst,
+        const int stride_row_x, const int stride_col_y, const int stride_col_dst,
+        const int nchannels_x, const int nchannels_y, const int nchannels_dst,
+        const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        cudaStream_t stream) {
+    switch (type_x) {
         case GGML_TYPE_Q4_0:
-            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_Q4_1:
-            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_Q5_0:
-            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_Q5_1:
-            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_Q8_0:
-            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_Q2_K:
-            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_Q3_K:
-            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_Q4_K:
-            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_Q5_K:
-            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_Q6_K:
-            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_IQ2_XXS:
-            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_IQ2_XS:
-            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_IQ2_S:
-            mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_IQ3_XXS:
-            mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_IQ1_S:
-            mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_IQ1_M:
-            mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_IQ4_NL:
-            mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_IQ4_XS:
-            mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         case GGML_TYPE_IQ3_S:
-            mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
             break;
         default:
             GGML_ABORT("fatal error");
             break;
     }
+}
+
+void ggml_cuda_mul_mat_vec_q(
+        ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(        dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(!ids || ids->type  == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    cudaStream_t stream = ctx.stream();
+
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
+
+    GGML_ASSERT(        nb00       == ts_src0);
+    GGML_ASSERT(        nb10       == ts_src1);
+    GGML_ASSERT(        nb0        == ts_dst);
+    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+
+    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for  batch size 1.
+
+    const float   * src1_d =       (const float   *) src1->data;
+    const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
+    float         *  dst_d =       (float         *)  dst->data;
+
+    const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+    ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1);
+    {
+        const int64_t s11 = src1->nb[1] / ts_src1;
+        const int64_t s12 = src1->nb[2] / ts_src1;
+        const int64_t s13 = src1->nb[3] / ts_src1;
+        quantize_row_q8_1_cuda(src1_d, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
+    }
+
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s11 = ne10_padded / QK8_1;
+    const int64_t s1  =  dst->nb[1] / ts_dst;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
+
+    const int64_t s12 = ne11*s11;
+    const int64_t s13 = ne12*s12;
+
+    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
+    const int64_t ncols_dst          = ids ? ne2  : ne1;
+    const int64_t nchannels_y        = ids ? ne11 : ne12;
+    const int64_t nchannels_dst      = ids ? ne1  : ne2;
+    const int64_t stride_col_dst     = ids ? s2   : s1;
+    const int64_t stride_col_y       = ids ? s12  : s11;
+    const int64_t stride_channel_dst = ids ? s1   : s2;
+    const int64_t stride_channel_y   = ids ? s11  : s12;
+
+    mul_mat_vec_q_switch_type(
+        src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00,
+        ne01,              ncols_dst,     s01, stride_col_y,     stride_col_dst,
+        ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+        ne03,              ne3,           s03, s13,              s3,                 stream);
+}
+
+void ggml_cuda_op_mul_mat_vec_q(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    const int64_t ne10 = src1->ne[0];
+    GGML_ASSERT(ne10 % QK8_1 == 0);
+
+    const int64_t ne0 = dst->ne[0];
+
+    int id = ggml_cuda_get_device();
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the kernel writes into
+    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
+
+    const int stride_row_x = ne00 / ggml_blck_size(src0->type);
+    const int stride_col_y = src1_padded_row_size / QK8_1;
+
+    mul_mat_vec_q_switch_type(
+        src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
 
     GGML_UNUSED(src1);
     GGML_UNUSED(dst);
index d9e42fdd6d16c01b2594ffc9e6012aa0e98f37a5..39dc7d33eb5ac68aadd31ada9d02639fe5a13372 100644 (file)
@@ -2,6 +2,9 @@
 
 #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
 
+void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
 void ggml_cuda_op_mul_mat_vec_q(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
index 1702e4ce2feba476c0637f32d3280df81cf2f4c4..3bab47d56a22ee98b4f52b80979d2f155115d962 100644 (file)
@@ -1,30 +1,40 @@
 #include "quantize.cuh"
 #include <cstdint>
 
-static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
-    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void quantize_q8_1(
+        const float * __restrict__ x, void * __restrict__ vy,
+        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t ne0, const int ne1, const int ne2) {
+    const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (ix0 >= kx0_padded) {
+    if (i0 >= ne0) {
         return;
     }
 
-    const int64_t ix1 = blockIdx.y;
+    const int64_t i1 = blockIdx.y;
+    const int64_t i2 = blockIdx.z % ne2;
+    const int64_t i3 = blockIdx.z / ne2;
+
+    const int64_t & i00 = i0;
+    const int64_t & i01 = i1;
+    const int64_t & i02 = i2;
+    const int64_t & i03 = i3;
 
-    const int64_t i_padded = ix1*kx0_padded + ix0;
+    const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0;
 
     block_q8_1 * y = (block_q8_1 *) vy;
 
-    const int64_t ib = i_padded / QK8_1; // block index
-    const int64_t iqs = i_padded % QK8_1; // quant index
+    const int64_t ib  = i_cont / QK8_1; // block index
+    const int64_t iqs = i_cont % QK8_1; // quant index
 
-    const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
+    const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f;
     float amax = fabsf(xi);
     float sum = xi;
 
     amax = warp_reduce_max(amax);
-    sum = warp_reduce_sum(sum);
+    sum  = warp_reduce_sum(sum);
 
-    const float d = amax / 127;
+    const float  d = amax / 127;
     const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
 
     y[ib].qs[iqs] = q;
@@ -127,43 +137,45 @@ static __global__ void quantize_mmq_q8_1(
 }
 
 void quantize_row_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
-    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
-    GGML_ASSERT(kx0_padded % QK8_1 == 0);
+    GGML_ASSERT(ne0 % QK8_1 == 0);
 
-    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-    const dim3 num_blocks(block_num_x, kx1*channels, 1);
+    const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
-
-    GGML_UNUSED(type_x);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
+    GGML_UNUSED(type_src0);
 }
 
 void quantize_mmq_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
-    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
-    GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+    GGML_ASSERT(ne0 % (4*QK8_1) == 0);
 
-    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
-    const dim3 num_blocks(block_num_x, kx1, channels);
+    const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
+    const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
-    switch (mmq_get_q8_1_ds_layout(type_x)) {
+    switch (mmq_get_q8_1_ds_layout(type_src0)) {
         case MMQ_Q8_1_DS_LAYOUT_D4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
             break;
         case MMQ_Q8_1_DS_LAYOUT_DS4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
             break;
         case MMQ_Q8_1_DS_LAYOUT_D2S6:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
             break;
         default:
             GGML_ABORT("fatal error");
             break;
     }
+    GGML_UNUSED(s01);
+    GGML_UNUSED(s02);
+    GGML_UNUSED(s03);
 }
index 03bf322b95873613a1b7dedd6812da1c9eb84868..b627c4e4008b4a3bc010c8251b7ea4fe8702c57c 100644 (file)
@@ -12,13 +12,13 @@ static_assert(MATRIX_ROW_PADDING %    CUDA_QUANTIZE_BLOCK_SIZE      == 0, "Risk
 static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
 
 typedef void (*quantize_cuda_t)(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
-    const ggml_type type_x, cudaStream_t stream);
+    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
 
 void quantize_row_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
-    const ggml_type type_x, cudaStream_t stream);
+    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
 
 void quantize_mmq_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
-    const ggml_type type_x, cudaStream_t stream);
+    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
index 40091a0ef07b43230c6d7391f8b4a1045d1175b4..ba195e1d100d320298cdf12442e8d5de2995e616 100644 (file)
@@ -1,3 +1,5 @@
+#pragma once
+
 #include "common.cuh"
 #include <cstdint>
 
index 5f6f87d1a3a7b7aa23f49cb7e31e65a54d0e30d5..61751755b317b435c0e04d3e0e765ad590ef482d 100644 (file)
@@ -2071,7 +2071,7 @@ struct test_mul_mat_id : public test_case {
     const ggml_type type_b;
     const int n_mats;
     const int n_used;
-    const bool b; // brodcast b matrix
+    const bool b; // broadcast b matrix
     const int64_t m;
     const int64_t n;
     const int64_t k;