]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda/hip: fix loop unrolling in ssm-conv (llama/20369)
authoruvos <redacted>
Wed, 11 Mar 2026 05:04:32 +0000 (06:04 +0100)
committerGeorgi Gerganov <redacted>
Sun, 15 Mar 2026 19:50:13 +0000 (21:50 +0200)
src/ggml-cuda/ssm-conv.cu

index 85e82b5a422961823d290155b8120cf431641233..69985cd335c9f27c53dae43d370c169551e29eb7 100644 (file)
@@ -76,7 +76,7 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
     int row = tid / load_cols;
     int col = tid % load_cols;
 #pragma unroll
-    for (int idx = tid; idx < total_elems; idx += split_d_inner) {
+    for (int idx = 0; idx < total_elems; idx += split_d_inner) {
         if (row < (int)split_d_inner) {
             smem[row * n_cols + col] = x_block[row * stride_x + col];
         }
@@ -84,6 +84,9 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
         col += split_d_inner;
         row += col / load_cols;
         col  = col % load_cols;
+        if (idx >= total_elems - tid - split_d_inner) {
+            break;
+        }
     }
     __syncthreads();