From: Sigbjørn Skjæret Date: Thu, 30 Oct 2025 07:56:28 +0000 (+0100) Subject: cuda : fix argsort with 64k+ rows (llama/16849) X-Git-Tag: upstream/0.9.4.185~58 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=5f7ee94439f9a9f3da981400120528073417752c;p=pkg%2Fggml%2Fsources%2Fggml cuda : fix argsort with 64k+ rows (llama/16849) --- diff --git a/src/ggml-cuda/argsort.cu b/src/ggml-cuda/argsort.cu index 6e7b90d4..3722cf3a 100644 --- a/src/ggml-cuda/argsort.cu +++ b/src/ggml-cuda/argsort.cu @@ -87,7 +87,7 @@ template static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) { // bitonic sort int col = threadIdx.x; - int row = blockIdx.y; + int row = blockIdx.x; if (col >= ncols_pad) { return; @@ -151,7 +151,7 @@ static void argsort_f32_i32_cuda_bitonic(const float * x, const int ncols_pad = next_power_of_2(ncols); const dim3 block_dims(ncols_pad, 1, 1); - const dim3 block_nums(1, nrows, 1); + const dim3 block_nums(nrows, 1, 1); const size_t shared_mem = ncols_pad * sizeof(int); // FIXME: this limit could be raised by ~2-4x on Ampere or newer diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3139119a..4b1304c2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7111,7 +7111,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order)); - test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // bailingmoe2 (group selection) + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024 + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection) } for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {