]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : use WARP_SIZE/2 for argmax reduction offset (llama/18092)
authorAadeshveer Singh <redacted>
Wed, 17 Dec 2025 03:47:01 +0000 (09:17 +0530)
committerGeorgi Gerganov <redacted>
Wed, 17 Dec 2025 11:55:04 +0000 (13:55 +0200)
src/ggml-cuda/argmax.cu

index 5340eedc08916cdaace562b248ef54d900ce2021..51967c667cfd88a8d9d81ddbf827bce402d65761 100644 (file)
@@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
     }
 
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
+    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
         const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
         const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
         if (val > maxval) {
@@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
                 argmax = shared_argmax[lane_id];
             }
 #pragma unroll
-            for (int offset = 16; offset > 0; offset >>= 1) {
+            for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
                 const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
                 const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
                 if (val > maxval) {