From: Aadeshveer Singh Date: Wed, 17 Dec 2025 03:47:01 +0000 (+0530) Subject: ggml : use WARP_SIZE/2 for argmax reduction offset (llama/18092) X-Git-Tag: upstream/0.9.4.395~2 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=23c6c8ae9f0243b4e25a666ace24c537dfb6ca0e;p=pkg%2Fggml%2Fsources%2Fggml ggml : use WARP_SIZE/2 for argmax reduction offset (llama/18092) --- diff --git a/src/ggml-cuda/argmax.cu b/src/ggml-cuda/argmax.cu index 5340eedc..51967c66 100644 --- a/src/ggml-cuda/argmax.cu +++ b/src/ggml-cuda/argmax.cu @@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest } #pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { + for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) { const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); if (val > maxval) { @@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest argmax = shared_argmax[lane_id]; } #pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { + for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) { const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); if (val > maxval) {