From: Sigbjørn Skjæret Date: Thu, 30 Oct 2025 07:56:28 +0000 (+0100) Subject: cuda : fix argsort with 64k+ rows (llama/16849) X-Git-Tag: upstream/1.8.3~394 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=f7dfa39104dbb756fc0d839698edaffaf3c7ddaa;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp cuda : fix argsort with 64k+ rows (llama/16849) --- diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 6e7b90d4..3722cf3a 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -87,7 +87,7 @@ template static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) { // bitonic sort int col = threadIdx.x; - int row = blockIdx.y; + int row = blockIdx.x; if (col >= ncols_pad) { return; @@ -151,7 +151,7 @@ static void argsort_f32_i32_cuda_bitonic(const float * x, const int ncols_pad = next_power_of_2(ncols); const dim3 block_dims(ncols_pad, 1, 1); - const dim3 block_nums(1, nrows, 1); + const dim3 block_nums(nrows, 1, 1); const size_t shared_mem = ncols_pad * sizeof(int); // FIXME: this limit could be raised by ~2-4x on Ampere or newer