]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: fix top_k bug when there are ties in the input (llama/17659)
authorJeff Bolz <redacted>
Fri, 5 Dec 2025 21:03:19 +0000 (15:03 -0600)
committerGeorgi Gerganov <redacted>
Fri, 12 Dec 2025 15:53:19 +0000 (17:53 +0200)
* vulkan: Reduce temporary memory usage for TOP_K

- Compute row size for the temp buffer based on the output of the first pass.
- Update shader addressing math to use the output row size
- Pass the output row size as "ncols_output", what used to be "ncols_output" is now "k"

For the common case of K=40 and src0=(200000,1,1,1), this reduces the temporary buffer
from about 3.2MB to 500KB.

* vulkan: fix top_k bug when there are ties in the input

I noticed by inspection a bug in the vulkan top_k shader where if the least
value in the top_k appears multiple times we could end up writing those extra
copies out rather than some larger values (if the larger values are on higher
numbered threads).

I rewrote the test verification to handle this case, where the final index set
is not necessarily the same.

* Update tests/test-backend-ops.cpp

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp

index f57199b1adba79036c1e745e076e580c08feabd9..4a6ec1be77d5143ced2cd2367df76ac8517db9e3 100644 (file)
@@ -4013,7 +4013,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
                                   sizeof(int) * device->subgroup_size +
                                   2 * sizeof(int) +
-                                  (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
+                                  2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
             if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
                 nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
                 ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
index f794285ee15fabbd18078271e7baf8ec1e8d4f26..0b757f38e1868e473ed829b2e08487cd4d214e8c 100644 (file)
@@ -38,6 +38,7 @@ shared int counts[SUBGROUP_SIZE];
 shared int sh_min_idx;
 shared uint sh_total;
 shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
+shared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE];
 
 // Map float values to uint such that comparisons still work.
 // Positive values set the high bit, negative values are inverted.
@@ -156,25 +157,66 @@ void topk(const uint row) {
         // We need to compact these values to the start of the dst_row array.
         // Have each subgroup count how many items it'll store, so other
         // subgroups can compute their base offset.
-        bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
-        uvec4 b = subgroupBallot(top);
-        uint bit_count = subgroupBallotBitCount(b);
-        if ((tid % SUBGROUP_SIZE) == 0) {
-            offset_partials[tid / SUBGROUP_SIZE] = bit_count;
-        }
-        barrier();
+        // Values strictly greater than range_min must be stored. For values equal
+        // to range_min, there can be ties and it's possible we'll need to store
+        // an arbitrary subset of them.
+        // If total == p.k, have a fast path where we don't need to handle ties.
+        if (total == p.k) {
+            bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
+            uvec4 b = subgroupBallot(top);
+            uint bit_count = subgroupBallotBitCount(b);
+            if ((tid % SUBGROUP_SIZE) == 0) {
+                offset_partials[tid / SUBGROUP_SIZE] = bit_count;
+            }
+            barrier();
 
-        uint out_idx = 0;
-        [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
-            if (i < tid / SUBGROUP_SIZE) {
-                out_idx += offset_partials[i];
+            uint out_idx = 0;
+            [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
+                if (i < tid / SUBGROUP_SIZE) {
+                    out_idx += offset_partials[i];
+                }
             }
-        }
 
-        uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
-        if (top) {
-            // TODO: Copy directly to the output?
-            dst_row[out_idx + bit_count_ex] = v;
+            uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
+            if (top) {
+                // TODO: Copy directly to the output?
+                dst_row[out_idx + bit_count_ex] = v;
+            }
+        } else {
+            bool top = f2ui(intBitsToFloat(v.y)) > range_min;
+            bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min;
+            uvec4 b_top = subgroupBallot(top);
+            uvec4 b_eq_min = subgroupBallot(eq_min);
+            uint bit_count_top = subgroupBallotBitCount(b_top);
+            uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min);
+            if ((tid % SUBGROUP_SIZE) == 0) {
+                offset_partials[tid / SUBGROUP_SIZE] = bit_count_top;
+                eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min;
+            }
+            barrier();
+
+            uint out_idx = 0;
+            uint eq_min_base = 0;
+            uint eq_min_idx = 0;
+            [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
+                if (i < tid / SUBGROUP_SIZE) {
+                    out_idx += offset_partials[i];
+                    eq_min_idx += eq_min_partials[i];
+                }
+                eq_min_base += offset_partials[i];
+            }
+            // range_min values are stored at the end
+            eq_min_idx += eq_min_base;
+
+            uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top);
+            uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min);
+            if (top) {
+                // TODO: Copy directly to the output?
+                dst_row[out_idx + bit_count_ex_top] = v;
+            }
+            if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) {
+                dst_row[eq_min_idx + bit_count_ex_eq_min] = v;
+            }
         }
 
         barrier();