]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: remove DMMV, consolidate F16 mult mat vec (llama/10318)
authorJohannes Gäßler <redacted>
Sun, 17 Nov 2024 08:09:55 +0000 (09:09 +0100)
committerGeorgi Gerganov <redacted>
Wed, 20 Nov 2024 19:00:08 +0000 (21:00 +0200)
ggml/CMakeLists.txt
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/ggml/CMakeLists.txt
ggml/src/ggml-cuda/mmv.cu [new file with mode: 0644]
ggml/src/ggml-cuda/mmv.cuh [new file with mode: 0644]
ggml/src/ggml-hip/CMakeLists.txt
ggml/src/ggml-musa/ggml/CMakeLists.txt

index c02bfb59d85698c898a32b2f0e74baf43516a6be..b0a3cc8827ee2db5f037c28dc588d9e159687836 100644 (file)
@@ -128,14 +128,9 @@ option(GGML_LLAMAFILE                       "ggml: use LLAMAFILE"
 
 option(GGML_CUDA                            "ggml: use CUDA"                                  OFF)
 option(GGML_MUSA                            "ggml: use MUSA"                                  OFF)
-option(GGML_CUDA_FORCE_DMMV                 "ggml: use dmmv instead of mmvq CUDA kernels"     OFF)
 option(GGML_CUDA_FORCE_MMQ                  "ggml: use mmq kernels instead of cuBLAS"         OFF)
 option(GGML_CUDA_FORCE_CUBLAS               "ggml: always use cuBLAS instead of mmq kernels"  OFF)
-set   (GGML_CUDA_DMMV_X   "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels")
-set   (GGML_CUDA_MMV_Y     "1" CACHE STRING "ggml: y block size for mmv CUDA kernels")
 option(GGML_CUDA_F16                        "ggml: use 16 bit floats for some calculations"   OFF)
-set   (GGML_CUDA_KQUANTS_ITER "2" CACHE STRING
-                                            "ggml: iters./thread per block for Q2_K/Q6_K")
 set   (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                             "ggml: max. batch size for using peer access")
 option(GGML_CUDA_NO_PEER_COPY               "ggml: do not use peer to peer copies"            OFF)
index 07f043328a8646929060c3bba997ddeb669230bf..ef56e944d01ab6b640f7b765ed8f281b594e527a 100644 (file)
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/cross-entropy-loss.cuh"
 #include "ggml-cuda/diagmask.cuh"
-#include "ggml-cuda/dmmv.cuh"
 #include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
 #include "ggml-cuda/mmq.cuh"
+#include "ggml-cuda/mmv.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
@@ -1020,114 +1020,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)(
 
 #define MUL_MAT_SRC1_COL_STRIDE 128
 
-static __global__ void mul_mat_p021_f16_f32(
-    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
-
-    const half * x = (const half *) vx;
-
-    const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
-    const int channel = blockDim.z*blockIdx.z + threadIdx.z;
-    const int channel_x = channel / (nchannels_y / nchannels_x);
-
-    const int nrows_y = ncols_x;
-    const int nrows_dst = nrows_x;
-    const int row_dst = row_x;
-
-    float tmp = 0.0f;
-
-    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
-        const int col_x = col_x0 + threadIdx.x;
-
-        if (col_x >= ncols_x) {
-            break;
-        }
-
-        // x is transposed and permuted
-        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
-        const float xi = __half2float(x[ix]);
-
-        const int row_y = col_x;
-
-        // y is not transposed but permuted
-        const int iy = channel*nrows_y + row_y;
-
-        tmp += xi * y[iy];
-    }
-
-    // dst is not transposed and not permuted
-    const int idst = channel*nrows_dst + row_dst;
-
-    // sum up partial sums and write back result
-    tmp = warp_reduce_sum(tmp);
-
-    if (threadIdx.x == 0) {
-        dst[idst] = tmp;
-    }
-}
-
-static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
-    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {
-
-    const half * x = (const half *) vx;
-
-    const int row_x     = blockDim.y*blockIdx.y + threadIdx.y;
-    const int channel   = blockDim.z*blockIdx.z + threadIdx.z;
-    const int channel_x = channel / channel_x_divisor;
-
-    const int nrows_y   = ncols_x;
-    const int nrows_dst = nrows_x;
-    const int row_dst   = row_x;
-
-    const int idst = channel*nrows_dst + row_dst;
-
-    float tmp = 0.0f;
-
-    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
-        const int col_x = col_x0 + threadIdx.x;
-
-        if (col_x >= ncols_x) {
-            break;
-        }
-
-        const int row_y = col_x;
-
-        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const int iy = channel*nrows_y + row_y;
-
-        const float xi = __half2float(x[ix]);
-
-        tmp += xi * y[iy];
-    }
-
-    // sum up partial sums and write back result
-    tmp = warp_reduce_sum(tmp);
-
-    if (threadIdx.x == 0) {
-        dst[idst] = tmp;
-    }
-}
-
-static void ggml_mul_mat_p021_f16_f32_cuda(
-    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
-    const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
-
-    const dim3 block_nums(1, nrows_x, nchannels_y);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
-}
-
-static void ggml_mul_mat_vec_nc_f16_f32_cuda(
-    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
-    const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {
-
-    const dim3 block_nums(1, nrows_x, nchannels_y);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
-        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
-}
-
 static cudaError_t ggml_cuda_cpy_tensor_2d(
     void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
 
@@ -1654,58 +1546,6 @@ static void ggml_cuda_op_mul_mat(
     }
 }
 
-static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
-    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
-    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
-    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-
-    const int64_t ne12 = src1->ne[2];
-
-    cudaStream_t main_stream = ctx.stream();
-
-    void  * src0_ddq = src0->data;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf  = (float *) dst->data;
-
-    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
-}
-
-static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(!ggml_is_transposed(src0));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-    GGML_ASSERT(!ggml_is_permuted(src0));
-    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-
-    const int64_t nb01 = src0->nb[1];
-    const int64_t nb02 = src0->nb[2];
-
-    const int64_t ne12 = src1->ne[2];
-
-    cudaStream_t main_stream = ctx.stream();
-
-    void  * src0_ddq = src0->data;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf  = (float *) dst->data;
-
-    const int64_t row_stride_x = nb01 / sizeof(half);
-    const int64_t channel_stride_x = nb02 / sizeof(half);
-
-    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
-}
-
 static __global__ void k_compute_batched_ptrs(
         const half * src0_as_f16, const half * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
@@ -1879,21 +1719,17 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
 
-    bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
+    bool use_mul_mat_vec   = src0->type == GGML_TYPE_F16
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
-    bool          use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
+        && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
-    bool              use_mul_mat_q =  ggml_is_quantized(src0->type)
+    bool use_mul_mat_q     = ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
 
-    // if mmvq is available it's a better choice than dmmv:
-#ifndef GGML_CUDA_FORCE_DMMV
-    use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
-#endif // GGML_CUDA_FORCE_DMMV
-
-    bool any_gpus_with_slow_fp16 = false;
+    bool any_gpus_with_slow_fp16   = false;
+    bool any_gpus_without_fp16_mma = false;
 
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1904,14 +1740,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
                 continue;
             }
 
-            const int cc            = ggml_cuda_info().devices[id].cc;
-            use_mul_mat_q           = use_mul_mat_q           && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
+            const int cc              = ggml_cuda_info().devices[id].cc;
+            use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_available(cc);
+            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
         }
     } else {
-        const int cc            = ggml_cuda_info().devices[ctx.device].cc;
-        use_mul_mat_q           = use_mul_mat_q           && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
+        const int cc              = ggml_cuda_info().devices[ctx.device].cc;
+        use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_available(cc);
+        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
     }
 
     // debug helpers
@@ -1922,18 +1760,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
-        // FP32 precision KQ single-batch for batch size 1 without FlashAttention
-        ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
-    } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
-        // FP32 precision KQV single-batch for batch size 1 without FlashAttention
-        ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
+    if (!split && src0->type == GGML_TYPE_F16 && src1->ne[1] == 1 && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+        ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
                && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
-    } else if (use_dequantize_mul_mat_vec) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
+    } else if (use_mul_mat_vec) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
     } else if (use_mul_mat_vec_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
index 860552f3a96df86f205e01bebfabe831b10ff3b5..3dde0f3666187660be09795eae2f7568ef584c41 100644 (file)
@@ -54,21 +54,12 @@ if (CUDAToolkit_FOUND)
     target_link_libraries(ggml-cuda PRIVATE ggml-base)
     target_include_directories(ggml-cuda PRIVATE . ..)
 
-    # TODO: change the definitions to this target only
-
-    add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
-    add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
-    add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
 
     if (GGML_CUDA_GRAPHS)
         add_compile_definitions(GGML_CUDA_USE_GRAPHS)
     endif()
 
-    if (GGML_CUDA_FORCE_DMMV)
-        add_compile_definitions(GGML_CUDA_FORCE_DMMV)
-    endif()
-
     if (GGML_CUDA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()
@@ -81,10 +72,6 @@ if (CUDAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()
 
-    if (DEFINED GGML_CUDA_DMMV_Y)
-        add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility
-    endif()
-
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()
diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu
new file mode 100644 (file)
index 0000000..cfe91f4
--- /dev/null
@@ -0,0 +1,223 @@
+#include "common.cuh"
+#include "mmv.cuh"
+
+template <typename type_acc, int block_size>
+static __global__ void mul_mat_vec(
+        const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
+        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
+    const int64_t row     = blockIdx.x;
+    const int64_t channel = blockIdx.z;
+    const int     tid     = threadIdx.x;
+
+    x   += (channel/channel_ratio)*stride_channel_x + row*stride_row;
+    y   +=  channel               *stride_channel_y;
+    dst +=  channel               *stride_channel_dst;
+
+    const half2  * x2 = (const half2  *) x;
+    const float2 * y2 = (const float2 *) y;
+
+    extern __shared__ char data_mmv[];
+    float * buf_iw = (float *) data_mmv;
+
+    if (block_size > WARP_SIZE) {
+        if (tid < WARP_SIZE) {
+            buf_iw[tid] = 0.0f;
+        }
+        __syncthreads();
+    }
+
+    float sumf;
+
+    if (std::is_same<type_acc, float>::value) {
+        sumf = 0.0f;
+
+        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            const float2 tmpx = __half22float2(x2[col2]);
+            const float2 tmpy = y2[col2];
+            sumf += tmpx.x * tmpy.x;
+            sumf += tmpx.y * tmpy.y;
+        }
+    } else {
+#ifdef FP16_AVAILABLE
+        half2 sumh2 = make_half2(0.0f, 0.0f);
+
+        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            const float2 tmp = y2[col2];
+            sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
+        }
+
+        sumf = __low2float(sumh2) + __high2float(sumh2);
+#else
+        NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+    }
+
+    sumf = warp_reduce_sum(sumf);
+
+    if (block_size > WARP_SIZE) {
+        buf_iw[tid/WARP_SIZE] = sumf;
+        __syncthreads();
+        if (tid > WARP_SIZE) {
+            return;
+        }
+        sumf = buf_iw[tid];
+        sumf = warp_reduce_sum(sumf);
+    }
+
+    if (tid != 0) {
+        return;
+    }
+
+    dst[row] = sumf;
+}
+
+template <typename type_acc>
+static void launch_mul_mat_vec_cuda(
+        const half * x, const float * y, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        cudaStream_t stream) {
+    GGML_ASSERT(ncols      % 2 == 0);
+    GGML_ASSERT(stride_row % 2 == 0);
+    GGML_ASSERT(nchannels_y % nchannels_x == 0);
+    const int64_t channel_ratio = nchannels_y / nchannels_x;
+
+    int64_t block_size_best = WARP_SIZE;
+    int64_t niter_best      = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
+    for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
+        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
+        if (niter < niter_best) {
+            niter_best      = niter;
+            block_size_best = block_size;
+        }
+    }
+
+    const int smem = WARP_SIZE*sizeof(float);
+    const dim3 block_nums(nrows, 1, nchannels_y);
+    const dim3 block_dims(block_size_best, 1, 1);
+    switch (block_size_best) {
+        case   32: {
+            mul_mat_vec<type_acc,  32><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case   64: {
+            mul_mat_vec<type_acc,  64><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case   96: {
+            mul_mat_vec<type_acc,  96><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  128: {
+            mul_mat_vec<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  160: {
+            mul_mat_vec<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  192: {
+            mul_mat_vec<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  224: {
+            mul_mat_vec<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  256: {
+            mul_mat_vec<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+static void mul_mat_vec_cuda(
+        const half * x, const float * y, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        enum ggml_prec prec, cudaStream_t stream) {
+    switch (prec) {
+        case GGML_PREC_DEFAULT: {
+            launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+        } break;
+        case GGML_PREC_F32: {
+            launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+        } break;
+    }
+}
+
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+
+    GGML_ASSERT(src1->ne[1] == 1);
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+    const half  * src0_d = (const half  *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       *  dst_d = (float       *)  dst->data;
+
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne12 = src1->ne[2];
+    GGML_ASSERT(dst->ne[2] == ne12);
+
+    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_ASSERT(src1->ne[3] == 1);
+    GGML_ASSERT( dst->ne[3] == 1);
+
+    const int64_t stride_row         = src0->nb[1] / ggml_type_size(src0->type);
+    const int64_t channel_stride_x   = src0->nb[2] / ggml_type_size(src0->type);
+    const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
+    const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);
+
+    mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+}
+
+void ggml_cuda_op_mul_mat_vec(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    GGML_ASSERT(src1_ncols == 1);
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+
+    // ggml_cuda_op provides single, contiguous matrices
+    const int64_t stride_row         = ne00;
+    const int64_t nchannels_x        = 1;
+    const int64_t nchannels_y        = 1;
+    const int64_t channel_stride_x   = 0;
+    const int64_t channel_stride_y   = 0;
+    const int64_t channel_stride_dst = 0;
+
+    mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+        nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(src1);
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src1_ddq_i);
+    GGML_UNUSED(src1_ncols);
+    GGML_UNUSED(src1_padded_row_size);
+}
diff --git a/ggml/src/ggml-cuda/mmv.cuh b/ggml/src/ggml-cuda/mmv.cuh
new file mode 100644 (file)
index 0000000..78a1cd4
--- /dev/null
@@ -0,0 +1,12 @@
+#include "common.cuh"
+
+// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
+#define MMV_MAX_ROWS 512
+
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_cuda_op_mul_mat_vec(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);
index 5ed186ded16be2eda23b5ee5a52a8031f06eca20..fccf8eb8440b8822f3e98ffa3f2848d7ecfd0e9d 100644 (file)
@@ -75,18 +75,11 @@ target_include_directories(ggml-hip PRIVATE . ..)
 target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
 
 add_compile_definitions(GGML_USE_HIP)
-add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
-add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
-add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
 
 if (GGML_HIP_UMA)
     add_compile_definitions(GGML_HIP_UMA)
 endif()
 
-if (GGML_CUDA_FORCE_DMMV)
-    add_compile_definitions(GGML_CUDA_FORCE_DMMV)
-endif()
-
 if (GGML_CUDA_FORCE_MMQ)
     add_compile_definitions(GGML_CUDA_FORCE_MMQ)
 endif()
index 8edc75cc5a9c7ab0488febcc5864e3195343b73d..f3c013692054071f72ff563d3cc0c9c8a4dd50de 100644 (file)
@@ -58,19 +58,12 @@ if (MUSAToolkit_FOUND)
     target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
 
     add_compile_definitions(GGML_USE_MUSA)
-    add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
-    add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
-    add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
 
     if (GGML_CUDA_GRAPHS)
         add_compile_definitions(GGML_CUDA_USE_GRAPHS)
     endif()
 
-    if (GGML_CUDA_FORCE_DMMV)
-        add_compile_definitions(GGML_CUDA_FORCE_DMMV)
-    endif()
-
     if (GGML_CUDA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()
@@ -83,10 +76,6 @@ if (MUSAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()
 
-    if (DEFINED GGML_CUDA_DMMV_Y)
-        add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility
-    endif()
-
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()