]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA/HIP: fix ssm_scan on devices where warp size is not 32 (llama/14196)
authoruvos <redacted>
Sun, 15 Jun 2025 15:30:13 +0000 (17:30 +0200)
committerGeorgi Gerganov <redacted>
Wed, 18 Jun 2025 09:40:34 +0000 (12:40 +0300)
ggml/src/ggml-cuda/ssm-scan.cu

index 37ee208c09d46c9fc7680eac9d333a3597c35415..2d34b836054f8eaed4e3decda8f1fe55854beceb 100644 (file)
@@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2)
                  float * __restrict__ dst, const int64_t L) {
     GGML_UNUSED(src1_nb0);
     GGML_UNUSED(src2_nb0);
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     const int bidx = blockIdx.x;  // split along B
     const int bidy = blockIdx.y;  // split along D
     const int tid  = threadIdx.x;
@@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2)
     if (N == 16) {
 #pragma unroll
         for (size_t i = 0; i < splitD / 4; i += 2) {
-            float value = A_block[(wid * warpSize + i) * stride_A + wtid];
+            float value = A_block[(wid * warp_size + i) * stride_A + wtid];
             // todo: bank conflict
             // I am always confused with how to use the swizzling method to solve
             // bank conflit. Hoping somebody can tell me.
-            smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
+            smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
         }
 #pragma unroll
         for (size_t i = 0; i < splitD / 4; i += 2) {
-            float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
-            smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
+            float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
+            smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
         }
     }