option(GGML_CUDA "ggml: use CUDA" OFF)
option(GGML_MUSA "ggml: use MUSA" OFF)
-option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF)
option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF)
-set (GGML_CUDA_DMMV_X "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels")
-set (GGML_CUDA_MMV_Y "1" CACHE STRING "ggml: y block size for mmv CUDA kernels")
option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF)
-set (GGML_CUDA_KQUANTS_ITER "2" CACHE STRING
- "ggml: iters./thread per block for Q2_K/Q6_K")
set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
"ggml: max. batch size for using peer access")
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
target_link_libraries(ggml-cuda PRIVATE ggml-base)
target_include_directories(ggml-cuda PRIVATE . ..)
- # TODO: change the definitions to this target only
-
- add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
- add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
- add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
if (GGML_CUDA_GRAPHS)
add_compile_definitions(GGML_CUDA_USE_GRAPHS)
endif()
- if (GGML_CUDA_FORCE_DMMV)
- add_compile_definitions(GGML_CUDA_FORCE_DMMV)
- endif()
-
if (GGML_CUDA_FORCE_MMQ)
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
endif()
add_compile_definitions(GGML_CUDA_NO_VMM)
endif()
- if (DEFINED GGML_CUDA_DMMV_Y)
- add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility
- endif()
-
if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
add_compile_definitions(GGML_CUDA_F16)
endif()
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cross-entropy-loss.cuh"
#include "ggml-cuda/diagmask.cuh"
-#include "ggml-cuda/dmmv.cuh"
#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/getrows.cuh"
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmq.cuh"
+#include "ggml-cuda/mmv.cuh"
#include "ggml-cuda/mmvq.cuh"
#include "ggml-cuda/norm.cuh"
#include "ggml-cuda/opt-step-adamw.cuh"
#define MUL_MAT_SRC1_COL_STRIDE 128
-static __global__ void mul_mat_p021_f16_f32(
- const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
- const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
-
- const half * x = (const half *) vx;
-
- const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
- const int channel = blockDim.z*blockIdx.z + threadIdx.z;
- const int channel_x = channel / (nchannels_y / nchannels_x);
-
- const int nrows_y = ncols_x;
- const int nrows_dst = nrows_x;
- const int row_dst = row_x;
-
- float tmp = 0.0f;
-
- for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
- const int col_x = col_x0 + threadIdx.x;
-
- if (col_x >= ncols_x) {
- break;
- }
-
- // x is transposed and permuted
- const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
- const float xi = __half2float(x[ix]);
-
- const int row_y = col_x;
-
- // y is not transposed but permuted
- const int iy = channel*nrows_y + row_y;
-
- tmp += xi * y[iy];
- }
-
- // dst is not transposed and not permuted
- const int idst = channel*nrows_dst + row_dst;
-
- // sum up partial sums and write back result
- tmp = warp_reduce_sum(tmp);
-
- if (threadIdx.x == 0) {
- dst[idst] = tmp;
- }
-}
-
-static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
- const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
- const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {
-
- const half * x = (const half *) vx;
-
- const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
- const int channel = blockDim.z*blockIdx.z + threadIdx.z;
- const int channel_x = channel / channel_x_divisor;
-
- const int nrows_y = ncols_x;
- const int nrows_dst = nrows_x;
- const int row_dst = row_x;
-
- const int idst = channel*nrows_dst + row_dst;
-
- float tmp = 0.0f;
-
- for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
- const int col_x = col_x0 + threadIdx.x;
-
- if (col_x >= ncols_x) {
- break;
- }
-
- const int row_y = col_x;
-
- const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
- const int iy = channel*nrows_y + row_y;
-
- const float xi = __half2float(x[ix]);
-
- tmp += xi * y[iy];
- }
-
- // sum up partial sums and write back result
- tmp = warp_reduce_sum(tmp);
-
- if (threadIdx.x == 0) {
- dst[idst] = tmp;
- }
-}
-
-static void ggml_mul_mat_p021_f16_f32_cuda(
- const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
- const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
-
- const dim3 block_nums(1, nrows_x, nchannels_y);
- const dim3 block_dims(WARP_SIZE, 1, 1);
- mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
-}
-
-static void ggml_mul_mat_vec_nc_f16_f32_cuda(
- const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
- const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {
-
- const dim3 block_nums(1, nrows_x, nchannels_y);
- const dim3 block_dims(WARP_SIZE, 1, 1);
- mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
- (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
-}
-
static cudaError_t ggml_cuda_cpy_tensor_2d(
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
}
}
-static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
- GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
- GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
- GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
- const int64_t ne02 = src0->ne[2];
-
- const int64_t ne12 = src1->ne[2];
-
- cudaStream_t main_stream = ctx.stream();
-
- void * src0_ddq = src0->data;
- float * src1_ddf = (float *) src1->data;
- float * dst_ddf = (float *) dst->data;
-
- ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
-}
-
-static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(!ggml_is_transposed(src0));
- GGML_ASSERT(!ggml_is_transposed(src1));
- GGML_ASSERT(!ggml_is_permuted(src0));
- GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
- const int64_t ne02 = src0->ne[2];
-
- const int64_t nb01 = src0->nb[1];
- const int64_t nb02 = src0->nb[2];
-
- const int64_t ne12 = src1->ne[2];
-
- cudaStream_t main_stream = ctx.stream();
-
- void * src0_ddq = src0->data;
- float * src1_ddf = (float *) src1->data;
- float * dst_ddf = (float *) dst->data;
-
- const int64_t row_stride_x = nb01 / sizeof(half);
- const int64_t channel_stride_x = nb02 / sizeof(half);
-
- ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
-}
-
static __global__ void k_compute_batched_ptrs(
const half * src0_as_f16, const half * src1_as_f16, char * dst,
const void ** ptrs_src, void ** ptrs_dst,
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
- bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
+ bool use_mul_mat_vec = src0->type == GGML_TYPE_F16
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
- && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
- bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+ && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
+ bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
- bool use_mul_mat_q = ggml_is_quantized(src0->type)
+ bool use_mul_mat_q = ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
- // if mmvq is available it's a better choice than dmmv:
-#ifndef GGML_CUDA_FORCE_DMMV
- use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
-#endif // GGML_CUDA_FORCE_DMMV
-
- bool any_gpus_with_slow_fp16 = false;
+ bool any_gpus_with_slow_fp16   = false;
+ bool any_gpus_without_fp16_mma = false;
if (split) {
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
continue;
}
- const int cc = ggml_cuda_info().devices[id].cc;
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
- any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
+ const int cc              = ggml_cuda_info().devices[id].cc;
+ use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+ any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_available(cc);
+ any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
}
} else {
- const int cc = ggml_cuda_info().devices[ctx.device].cc;
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
- any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
+ const int cc              = ggml_cuda_info().devices[ctx.device].cc;
+ use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+ any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_available(cc);
+ any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
}
// debug helpers
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
- if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
- // FP32 precision KQ single-batch for batch size 1 without FlashAttention
- ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
- } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
- // FP32 precision KQV single-batch for batch size 1 without FlashAttention
- ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
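+ // F16 mat-vec kernel for batch size 1: used when src0 has fewer than MMV_MAX_ROWS rows or at least one GPU lacks FP16 tensor cores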
+ if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+ ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
// KQ + KQV multi-batch without FlashAttention
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
- } else if (use_dequantize_mul_mat_vec) {
- ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
+ } else if (use_mul_mat_vec) {
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
} else if (use_mul_mat_vec_q) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
} else if (use_mul_mat_q) {
--- /dev/null
+#include "common.cuh"
+#include "mmv.cuh"
+
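+ // matrix-vector product dst = x*y with x in F16 and y, dst in F32:
+ // each block computes one row of dst and walks the columns in half2/float2 pairs (ncols2 = ncols/2)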
+template <typename type_acc, int block_size>
+static __global__ void mul_mat_vec(
+ const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
+ const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
+ const int64_t row = blockIdx.x;
+ const int64_t channel = blockIdx.z;
+ const int tid = threadIdx.x;
+
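+ // offset the base pointers to the row/channel handled by this block;
+ // channel_ratio = nchannels_y/nchannels_x broadcasts each x channel over multiple y channels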
+ x += (channel/channel_ratio)*stride_channel_x + row*stride_row;
+ y += channel *stride_channel_y;
+ dst += channel *stride_channel_dst;
+
+ const half2 * x2 = (const half2 *) x;
+ const float2 * y2 = (const float2 *) y;
+
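+ // shared memory for the per-warp partial sums, only used when block_size > WARP_SIZE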
+ extern __shared__ char data_mmv[];
+ float * buf_iw = (float *) data_mmv;
+
+ if (block_size > WARP_SIZE) {
+ if (tid < WARP_SIZE) {
+ buf_iw[tid] = 0.0f;
+ }
+ __syncthreads();
+ }
+
+ float sumf;
+
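+ // accumulate in FP32, or in FP16 via half2 arithmetic when type_acc is half and device FP16 is available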
+ if (std::is_same<type_acc, float>::value) {
+ sumf = 0.0f;
+
+ for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+ const float2 tmpx = __half22float2(x2[col2]);
+ const float2 tmpy = y2[col2];
+ sumf += tmpx.x * tmpy.x;
+ sumf += tmpx.y * tmpy.y;
+ }
+ } else {
+#ifdef FP16_AVAILABLE
+ half2 sumh2 = make_half2(0.0f, 0.0f);
+
+ for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+ const float2 tmp = y2[col2];
+ sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
+ }
+
+ sumf = __low2float(sumh2) + __high2float(sumh2);
+#else
+ NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+ }
+
+ sumf = warp_reduce_sum(sumf);
+
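+ // combine the per-warp partial sums through shared memory; only the first warp performs the final reduction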
+ if (block_size > WARP_SIZE) {
+ buf_iw[tid/WARP_SIZE] = sumf;
+ __syncthreads();
+ if (tid >= WARP_SIZE) {
+ return;
+ }
+ sumf = buf_iw[tid];
+ sumf = warp_reduce_sum(sumf);
+ }
+
+ if (tid != 0) {
+ return;
+ }
+
+ dst[row] = sumf;
+}
+
+template <typename type_acc>
+static void launch_mul_mat_vec_cuda(
+ const half * x, const float * y, float * dst,
+ const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+ cudaStream_t stream) {
+ GGML_ASSERT(ncols % 2 == 0);
+ GGML_ASSERT(stride_row % 2 == 0);
+ GGML_ASSERT(nchannels_y % nchannels_x == 0);
+ const int64_t channel_ratio = nchannels_y / nchannels_x;
+
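+ // pick the block size (a multiple of WARP_SIZE, at most 256) that minimizes the number of loop iterations per thread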
+ int64_t block_size_best = WARP_SIZE;
+ int64_t niter_best = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
+ for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
+ const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
+ if (niter < niter_best) {
+ niter_best = niter;
+ block_size_best = block_size;
+ }
+ }
+
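+ // one block per dst row and per y channel; WARP_SIZE floats of shared memory for the inter-warp reduction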
+ const int smem = WARP_SIZE*sizeof(float);
+ const dim3 block_nums(nrows, 1, nchannels_y);
+ const dim3 block_dims(block_size_best, 1, 1);
+ switch (block_size_best) {
+ case 32: {
+ mul_mat_vec<type_acc, 32><<<block_nums, block_dims, smem, stream>>>
+ (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+ } break;
+ case 64: {
+ mul_mat_vec<type_acc, 64><<<block_nums, block_dims, smem, stream>>>
+ (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+ } break;
+ case 96: {
+ mul_mat_vec<type_acc, 96><<<block_nums, block_dims, smem, stream>>>
+ (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+ } break;
+ case 128: {
+ mul_mat_vec<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
+ (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+ } break;
+ case 160: {
+ mul_mat_vec<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
+ (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+ } break;
+ case 192: {
+ mul_mat_vec<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
+ (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+ } break;
+ case 224: {
+ mul_mat_vec<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
+ (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+ } break;
+ case 256: {
+ mul_mat_vec<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
+ (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+ } break;
+ default: {
+ GGML_ABORT("fatal error");
+ } break;
+ }
+}
+
+static void mul_mat_vec_cuda(
+ const half * x, const float * y, float * dst,
+ const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+ enum ggml_prec prec, cudaStream_t stream) {
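+ // GGML_PREC_DEFAULT uses FP16 accumulators, GGML_PREC_F32 forces FP32 accumulators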
+ switch (prec) {
+ case GGML_PREC_DEFAULT: {
+ launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+ stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+ } break;
+ case GGML_PREC_F32: {
+ launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+ stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+ } break;
+ }
+}
+
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+
+ GGML_ASSERT(src1->ne[1] == 1);
+
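+ // on devices without fast FP16 arithmetic always accumulate in FP32, otherwise honor the precision requested in op_params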
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+ const half * src0_d = (const half *) src0->data;
+ const float * src1_d = (const float *) src1->data;
+ float * dst_d = (float *) dst->data;
+
+ const int64_t ne02 = src0->ne[2];
+ const int64_t ne12 = src1->ne[2];
+ GGML_ASSERT(dst->ne[2] == ne12);
+
+ GGML_ASSERT(src0->ne[3] == 1);
+ GGML_ASSERT(src1->ne[3] == 1);
+ GGML_ASSERT( dst->ne[3] == 1);
+
+ const int64_t stride_row = src0->nb[1] / ggml_type_size(src0->type);
+ const int64_t channel_stride_x = src0->nb[2] / ggml_type_size(src0->type);
+ const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type);
+ const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type);
+
+ mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+}
+
+void ggml_cuda_op_mul_mat_vec(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t row_diff = row_high - row_low;
+
+ GGML_ASSERT(src1_ncols == 1);
+
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+ // ggml_cuda_op provides single, contiguous matrices
+ const int64_t stride_row = ne00;
+ const int64_t nchannels_x = 1;
+ const int64_t nchannels_y = 1;
+ const int64_t channel_stride_x = 0;
+ const int64_t channel_stride_y = 0;
+ const int64_t channel_stride_dst = 0;
+
+ mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+ nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+
+ GGML_UNUSED(ctx);
+ GGML_UNUSED(src1);
+ GGML_UNUSED(dst);
+ GGML_UNUSED(src1_ddq_i);
+ GGML_UNUSED(src1_ncols);
+ GGML_UNUSED(src1_padded_row_size);
+}
--- /dev/null
+#include "common.cuh"
+
+// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
+#define MMV_MAX_ROWS 512
+
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_cuda_op_mul_mat_vec(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream);
target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
add_compile_definitions(GGML_USE_HIP)
-add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
-add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
-add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
if (GGML_HIP_UMA)
add_compile_definitions(GGML_HIP_UMA)
endif()
-if (GGML_CUDA_FORCE_DMMV)
- add_compile_definitions(GGML_CUDA_FORCE_DMMV)
-endif()
-
if (GGML_CUDA_FORCE_MMQ)
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
endif()
target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
add_compile_definitions(GGML_USE_MUSA)
- add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
- add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
- add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
if (GGML_CUDA_GRAPHS)
add_compile_definitions(GGML_CUDA_USE_GRAPHS)
endif()
- if (GGML_CUDA_FORCE_DMMV)
- add_compile_definitions(GGML_CUDA_FORCE_DMMV)
- endif()
-
if (GGML_CUDA_FORCE_MMQ)
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
endif()
add_compile_definitions(GGML_CUDA_NO_VMM)
endif()
- if (DEFINED GGML_CUDA_DMMV_Y)
- add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility
- endif()
-
if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
add_compile_definitions(GGML_CUDA_F16)
endif()