From: Johannes Gäßler Date: Sun, 4 May 2025 12:16:39 +0000 (+0200) Subject: CUDA: fix race condition in MMQ stream-k fixup (#13299) X-Git-Tag: upstream/0.0.5318~43 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=93c4e23905987949b714b21ae918ff6bfb55fe36;p=pkg%2Fggml%2Fsources%2Fllama.cpp CUDA: fix race condition in MMQ stream-k fixup (#13299) --- diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index fc6ce008..e1096dce 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2958,6 +2958,7 @@ static __global__ void mul_mat_q_stream_k_fixup( for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) { ids_dst_shared[j] = ids_dst[col_low + j]; } + __syncthreads(); const int offset_dst = it*mmq_y; dst += offset_dst;