]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: fix crash on uneven context without FA (llama/16988)
authorJohannes Gäßler <redacted>
Thu, 6 Nov 2025 13:05:47 +0000 (14:05 +0100)
committerGeorgi Gerganov <redacted>
Sun, 9 Nov 2025 21:38:03 +0000 (23:38 +0200)
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mmf.cu
ggml/src/ggml-cuda/mmf.cuh
ggml/src/ggml-cuda/mmvf.cu
ggml/src/ggml-cuda/mmvf.cuh

index 415a7e962d7795ac8f7abb932469df8ae3a8a3a6..049aece1b5234944418b76add805283149d1be13 100644 (file)
@@ -2113,7 +2113,7 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
         src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
 
     const int cc      = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
+    use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
 
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
                        ggml_backend_buft_is_cuda_split(src1->buffer->buft);
@@ -2207,16 +2207,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
             const int cc            = ggml_cuda_info().devices[id].cc;
             const int warp_size     = ggml_cuda_info().devices[id].warp_size;
             use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
-            use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
+            use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
+            use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
         }
     } else {
         const int cc            = ggml_cuda_info().devices[ctx.device].cc;
         const int warp_size     = ggml_cuda_info().devices[ctx.device].warp_size;
         use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
-        use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
+        use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
+        use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
     }
 
@@ -2287,7 +2287,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             return;
         }
 
-        if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], /*mul_mat_id=*/true)) {
+        if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src0->nb, src1->ne[2], /*mul_mat_id=*/true)) {
             ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
             return;
         }
index 2b0a61395b4588907d4499d71072a50d5e2fcaae..69a60aceb82b74484a42f78e314368b1d6646f8e 100644 (file)
@@ -119,15 +119,21 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
     }
 }
 
-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, bool mul_mat_id) {
-
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne,
+        const size_t * src0_nb, const int src1_ncols, bool mul_mat_id) {
     if (ggml_is_quantized(type)) {
         return false;
     }
 
-    if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
+    const size_t ts = ggml_type_size(type);
+    if (src0_ne[0] % (warp_size * (4/ts)) != 0) {
         return false;
     }
+    for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
+        if (src0_nb[i] % (2*ts) != 0) {
+            return false;
+        }
+    }
     if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
         return false;
     }
index f7e46e2f63b2fdf815582c1d043077adb966481b..45724e0911ec8f105e436ab234803ba4882288a0 100644 (file)
@@ -17,7 +17,7 @@ struct mmf_ids_data {
 
 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
 
-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const size_t * src0_nb, const int src1_ncols, bool mul_mat_id);
 
 template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
 __launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
index 4e31783436d80a4abc06645ecab14b3b3c1a0aee..526d90d7aee52618b84b75b7112edb8f9c1cb680 100644 (file)
@@ -716,10 +716,16 @@ void ggml_cuda_op_mul_mat_vec_f(
     GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size);
 }
 
-bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11) {
     if (src0_ne[0] % 2 != 0) {
         return false;
     }
+    const size_t ts = ggml_type_size(type);
+    for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
+        if (src0_nb[i] % (2*ts) != 0) {
+            return false;
+        }
+    }
     switch (type) {
         case GGML_TYPE_F32:
             if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
index a205aa8e4c5380377a1f43d635563bf4a8ae116b..a09fbdc72022e68fc275db9afe8e29fd5835de3b 100644 (file)
@@ -9,4 +9,4 @@ void ggml_cuda_op_mul_mat_vec_f(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
 
-bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11);