]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
cuda : fix argsort with 64k+ rows (llama/16849)
authorSigbjørn Skjæret <redacted>
Thu, 30 Oct 2025 07:56:28 +0000 (08:56 +0100)
committerGeorgi Gerganov <redacted>
Sun, 9 Nov 2025 21:38:03 +0000 (23:38 +0200)
ggml/src/ggml-cuda/argsort.cu

index 6e7b90d42783f3a0d69c2eb4a9b2a7d4ec3236d0..3722cf3ab26ee7bc9e677636349f9c4bcd565e26 100644 (file)
@@ -87,7 +87,7 @@ template<ggml_sort_order order>
 static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
     // bitonic sort
     int col = threadIdx.x;
-    int row = blockIdx.y;
+    int row = blockIdx.x;
 
     if (col >= ncols_pad) {
         return;
@@ -151,7 +151,7 @@ static void argsort_f32_i32_cuda_bitonic(const float *   x,
     const int ncols_pad = next_power_of_2(ncols);
 
     const dim3 block_dims(ncols_pad, 1, 1);
-    const dim3 block_nums(1, nrows, 1);
+    const dim3 block_nums(nrows, 1, 1);
     const size_t shared_mem = ncols_pad * sizeof(int);
 
     // FIXME: this limit could be raised by ~2-4x on Ampere or newer