]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : fix argsort with 64k+ rows (#16849)
authorSigbjørn Skjæret <redacted>
Thu, 30 Oct 2025 07:56:28 +0000 (08:56 +0100)
committerGitHub <redacted>
Thu, 30 Oct 2025 07:56:28 +0000 (08:56 +0100)
ggml/src/ggml-cuda/argsort.cu
tests/test-backend-ops.cpp

index 6e7b90d42783f3a0d69c2eb4a9b2a7d4ec3236d0..3722cf3ab26ee7bc9e677636349f9c4bcd565e26 100644 (file)
@@ -87,7 +87,7 @@ template<ggml_sort_order order>
 static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
     // bitonic sort
     int col = threadIdx.x;
-    int row = blockIdx.y;
+    int row = blockIdx.x;
 
     if (col >= ncols_pad) {
         return;
@@ -151,7 +151,7 @@ static void argsort_f32_i32_cuda_bitonic(const float *   x,
     const int ncols_pad = next_power_of_2(ncols);
 
     const dim3 block_dims(ncols_pad, 1, 1);
-    const dim3 block_nums(1, nrows, 1);
+    const dim3 block_nums(nrows, 1, 1);
     const size_t shared_mem = ncols_pad * sizeof(int);
 
     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
index 3139119af7ac03be0d220edc7c019a89168f762d..4b1304c2b891cc549656759d4c8753d6148a3b84 100644 (file)
@@ -7111,7 +7111,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
-        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // bailingmoe2 (group selection)
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
     }
 
     for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {