]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : fix bounds check for src0 rows in MMVQ kernel (whisper/2231)
authorGeorgi Gerganov <redacted>
Tue, 11 Jun 2024 14:39:01 +0000 (17:39 +0300)
committerGeorgi Gerganov <redacted>
Sun, 16 Jun 2024 17:30:48 +0000 (20:30 +0300)
* cuda : fix bounds check for src0 rows in MMVQ kernel

* Update ggml-cuda/mmvq.cu

Co-authored-by: Johannes Gäßler <redacted>
---------

Co-authored-by: Johannes Gäßler <redacted>
src/ggml-cuda/mmvq.cu

index 5f056e91e54606a163497b88057484039feb3a3d..e8d157169544f75cee2d59963c90b5e768a37e0c 100644 (file)
@@ -117,7 +117,7 @@ static __global__ void mul_mat_vec_q(
             tmp[j][i] = warp_reduce_sum(tmp[j][i]);
         }
 
-        if (threadIdx.x < rows_per_cuda_block) {
+        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
             dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
         }
     }