From: Johannes Gäßler
Date: Thu, 6 Nov 2025 13:05:47 +0000 (+0100)
Subject: CUDA: fix crash on uneven context without FA (llama/16988)
X-Git-Tag: upstream/1.8.3~361
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=b5d6fa438f51ad8f70a71a04b77784268fc0065c;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp

CUDA: fix crash on uneven context without FA (llama/16988)
---

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 415a7e96..049aece1 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2113,7 +2113,7 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
         src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
 
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
+    use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
 
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) || ggml_backend_buft_is_cuda_split(src1->buffer->buft);
 
@@ -2207,16 +2207,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
             const int cc        = ggml_cuda_info().devices[id].cc;
             const int warp_size = ggml_cuda_info().devices[id].warp_size;
             use_mul_mat_q     = use_mul_mat_q     && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            use_mul_mat_f     = use_mul_mat_f     && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
-            use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
+            use_mul_mat_f     = use_mul_mat_f     && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
+            use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
         }
     } else {
         const int cc        = ggml_cuda_info().devices[ctx.device].cc;
         const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
         use_mul_mat_q     = use_mul_mat_q     && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        use_mul_mat_f     = use_mul_mat_f     && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
-        use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
+        use_mul_mat_f     = use_mul_mat_f     && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
+        use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
     }
 
@@ -2287,7 +2287,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         return;
     }
 
-    if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], /*mul_mat_id=*/true)) {
+    if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src0->nb, src1->ne[2], /*mul_mat_id=*/true)) {
         ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
         return;
     }
diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu
index 2b0a6139..69a60ace 100644
--- a/ggml/src/ggml-cuda/mmf.cu
+++ b/ggml/src/ggml-cuda/mmf.cu
@@ -119,15 +119,21 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
     }
 }
 
-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, bool mul_mat_id) {
-
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne,
+        const size_t * src0_nb, const int src1_ncols, bool mul_mat_id) {
     if (ggml_is_quantized(type)) {
         return false;
     }
 
-    if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
+    const size_t ts = ggml_type_size(type);
+    if (src0_ne[0] % (warp_size * (4/ts)) != 0) {
         return false;
     }
+    for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+        if (src0_nb[i] % (2*ts) != 0) {
+            return false;
+        }
+    }
     if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
         return false;
     }
diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index f7e46e2f..45724e09 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -17,7 +17,7 @@ struct mmf_ids_data {
 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
 
-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const size_t * src0_nb, const int src1_ncols, bool mul_mat_id);
 
 template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
 __launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu
index 4e317834..526d90d7 100644
--- a/ggml/src/ggml-cuda/mmvf.cu
+++ b/ggml/src/ggml-cuda/mmvf.cu
@@ -716,10 +716,16 @@ void ggml_cuda_op_mul_mat_vec_f(
     GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size);
 }
 
-bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11) {
     if (src0_ne[0] % 2 != 0) {
         return false;
     }
+    const size_t ts = ggml_type_size(type);
+    for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+        if (src0_nb[i] % (2*ts) != 0) {
+            return false;
+        }
+    }
     switch (type) {
         case GGML_TYPE_F32:
             if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
diff --git a/ggml/src/ggml-cuda/mmvf.cuh b/ggml/src/ggml-cuda/mmvf.cuh
index a205aa8e..a09fbdc7 100644
--- a/ggml/src/ggml-cuda/mmvf.cuh
+++ b/ggml/src/ggml-cuda/mmvf.cuh
@@ -9,4 +9,4 @@ void ggml_cuda_op_mul_mat_vec_f(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
 
-bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11);
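
Note on the fix: the 2*ts alignment requirement suggests these kernels read src0 through 2-element vectorized loads (float2/half2-style), which are only safe when every higher-dimension byte stride of src0 is a multiple of two element sizes. An uneven (odd) context length without FlashAttention can produce KV-cache views whose strides violate that, hence the crash named in the subject. The sketch below restates the new predicate as a minimal, self-contained C++ program; tensor_view and its fields are hypothetical stand-ins for ggml_tensor's ne/nb metadata, not the actual ggml API:

    // vec2_check.cpp -- illustrative only; mirrors the checks added above.
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    constexpr int MAX_DIMS = 4; // mirrors GGML_MAX_DIMS

    struct tensor_view {
        int64_t ne[MAX_DIMS]; // extent of each dimension, in elements
        size_t  nb[MAX_DIMS]; // stride of each dimension, in bytes
        size_t  ts;           // bytes per element, e.g. 2 for F16
    };

    // Safe for 2-element vectorized loads only if the row length is even
    // and every row/plane stride is a multiple of 2*ts.
    static bool vec2_safe(const tensor_view & t) {
        if (t.ne[0] % 2 != 0) {
            return false; // odd row length: the last element has no partner
        }
        for (int i = 1; i < MAX_DIMS; ++i) {
            if (t.nb[i] % (2*t.ts) != 0) {
                return false; // a row/plane starts at an odd element offset
            }
        }
        return true;
    }

    int main() {
        // An F16 view sliced out of a parent whose rows hold 13 elements:
        // ne[0] = 12 passes the even-row check, but the row stride
        // nb[1] = 13*2 = 26 bytes is not a multiple of 2*ts = 4 bytes.
        const tensor_view v = {{12, 5, 1, 1}, {2, 26, 130, 130}, 2};
        printf("vec2-safe: %s\n", vec2_safe(v) ? "yes" : "no"); // prints "no"
        return 0;
    }

The pre-existing ne-only test was insufficient because a view can have an even row length while inheriting an odd-element row stride from its parent, exactly the situation in the example above; that is why the fix threads src0->nb into the should-use predicates and walks the strides as well.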