From: Aadeshveer Singh Date: Wed, 17 Dec 2025 03:47:01 +0000 (+0530) Subject: ggml : use WARP_SIZE/2 for argmax reduction offset (llama/18092) X-Git-Tag: upstream/1.8.3~124 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=41a95b8ba798cece458644337c08f82772a8942d;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp ggml : use WARP_SIZE/2 for argmax reduction offset (llama/18092) --- diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu index 5340eedc..51967c66 100644 --- a/ggml/src/ggml-cuda/argmax.cu +++ b/ggml/src/ggml-cuda/argmax.cu @@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest } #pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { + for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) { const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); if (val > maxval) { @@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest argmax = shared_argmax[lane_id]; } #pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { + for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) { const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); if (val > maxval) {