]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: Handle argsort with a large number of rows (#16851)
authorJeff Bolz <redacted>
Thu, 30 Oct 2025 06:27:41 +0000 (01:27 -0500)
committerGitHub <redacted>
Thu, 30 Oct 2025 06:27:41 +0000 (07:27 +0100)
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp

index 8a9f5980ea84317738a53853addff9183152478f..d0976519f263feb5b72e6d69658eab77c6ec580e 100644 (file)
@@ -1082,6 +1082,7 @@ struct vk_op_soft_max_push_constants {
 
 struct vk_op_argsort_push_constants {
     uint32_t ncols;
+    uint32_t nrows;
     int32_t order;
 };
 
@@ -8708,6 +8709,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         break;
     case GGML_OP_ARGSORT:
         elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
+        elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
         break;
     case GGML_OP_IM2COL:
         {
@@ -9954,9 +9956,11 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
     int32_t * op_params = (int32_t *)dst->op_params;
 
     uint32_t ncols = src0->ne[0];
+    uint32_t nrows = ggml_nrows(src0);
 
     ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
         ncols,
+        nrows,
         op_params[0],
     }, dryrun);
 }
index c81b84452e7697377451210c125fde486d92dbb5..c4e68bc02370ac862a69aed68b277a7c60ab3126 100644 (file)
@@ -14,6 +14,7 @@ layout (binding = 1)          buffer D {int data_d[];};
 
 layout (push_constant) uniform parameter {
     uint ncols;
+    uint nrows;
     uint order;
 } p;
 
@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
     dst_row[idx1] = tmp;
 }
 
-void argsort(bool needs_bounds_check) {
+void argsort(bool needs_bounds_check, const uint row) {
     // bitonic sort
     const int col = int(gl_LocalInvocationID.x);
-    const uint row = gl_WorkGroupID.y;
 
     const uint row_offset = row * p.ncols;
 
@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {
 
 void main() {
     if (p.ncols == BLOCK_SIZE) {
-        argsort(false);
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            argsort(false, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
     } else {
-        argsort(true);
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            argsort(true, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
     }
 }