]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Fix data race in CUDA's "cpy" kernel (influences GGML's DUP, CONT operations). (llama...
authorRail Chabdarov <redacted>
Sat, 14 Mar 2026 05:19:44 +0000 (06:19 +0100)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
* Fix datarace in CUDA's "cpy" kernel.

* Remove extra barrier by using more of shared memory.

ggml/src/ggml-cuda/cpy.cu

index ee84303ef0e99df8dc63f556e4ebdb7c7c6e3aa9..d208acf2d5f01bc8b00ca4b1f1c930e7118cb492 100644 (file)
@@ -56,7 +56,8 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
     const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x;  // transpose block offset
     const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
 
-    __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
+    __shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
+    int cur_tile_buf = 0;
 
 #pragma unroll
     for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
@@ -70,7 +71,7 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
             if(x < ne01 && y + j < ne00){
                 const int row = threadIdx.y+j;
                 const int col = threadIdx.x * sizeof(float)/sizeof(T);
-                T *tile2 = reinterpret_cast<T*>(tile[row]);
+                T *tile2 = reinterpret_cast<T*>(tile[cur_tile_buf][row]);
                 tile2[col] = src[imat*n + (y+j)*ne01 + x];
             }
         }
@@ -81,10 +82,12 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
         for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
             if (ty + j < ne01 && tx < ne00) {
                 const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T);
-                const T *tile2 = reinterpret_cast<const T*>(tile[threadIdx.x]);
+                const T *tile2 = reinterpret_cast<const T*>(tile[cur_tile_buf][threadIdx.x]);
                 dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];
             }
         }
+
+        cur_tile_buf = (cur_tile_buf + 1) % 2;
     }
 
     GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,