]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
cuda/hip: fix loop unrolling in ssm-conv (llama/20369)
authoruvos <redacted>
Wed, 11 Mar 2026 05:04:32 +0000 (06:04 +0100)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
ggml/src/ggml-cuda/ssm-conv.cu

index 85e82b5a422961823d290155b8120cf431641233..69985cd335c9f27c53dae43d370c169551e29eb7 100644 (file)
@@ -76,7 +76,7 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
     int row = tid / load_cols;
     int col = tid % load_cols;
 #pragma unroll
-    for (int idx = tid; idx < total_elems; idx += split_d_inner) {
+    for (int idx = 0; idx < total_elems; idx += split_d_inner) {
         if (row < (int)split_d_inner) {
             smem[row * n_cols + col] = x_block[row * stride_x + col];
         }
@@ -84,6 +84,9 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
         col += split_d_inner;
         row += col / load_cols;
         col  = col % load_cols;
+        if (idx >= total_elems - tid - split_d_inner) {
+            break;
+        }
     }
     __syncthreads();