]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: handle rope with large number of rows (llama/18306)
authorJeff Bolz <redacted>
Fri, 26 Dec 2025 15:53:46 +0000 (09:53 -0600)
committerGeorgi Gerganov <redacted>
Wed, 31 Dec 2025 15:52:09 +0000 (17:52 +0200)
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp

index 1459b2608e5c1148c33c400c19cd66fd95480cd4..e7ce518fbac68eca67ca24881ccaab09b5730a81 100644 (file)
@@ -1192,6 +1192,7 @@ struct vk_op_diag_mask_push_constants {
 struct vk_op_rope_push_constants {
     uint32_t rope_mode;
     uint32_t ncols;
+    uint32_t nrows;
     uint32_t n_dims;
     float freq_scale;
     uint32_t p_delta_rows;
@@ -9090,10 +9091,20 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
         } break;
     case GGML_OP_DIAG_MASK_INF:
-    case GGML_OP_ROPE:
-    case GGML_OP_ROPE_BACK:
         elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
         break;
+    case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
+        {
+            uint32_t nrows = (uint32_t)ggml_nrows(src0);
+            uint32_t z = 1;
+            if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) {
+                z = CEIL_DIV(nrows, 32768);
+                nrows = 32768;
+            }
+            elements = { nrows, (uint32_t)ne00, z };
+
+        } break;
     case GGML_OP_GET_ROWS:
         elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
         elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
@@ -10021,7 +10032,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
     uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
 
     vk_op_rope_push_constants rope {
-        (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
+        (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
         freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
         has_ff, (uint32_t)src0->ne[2], nb01, nb02,
         { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
index 7c1fb1cd22440ece5cb42e02866a65f79fa4415f..f7587468a81534ef1c20abba2b49536705500db0 100644 (file)
@@ -6,6 +6,9 @@
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_multi(i0, i1, pc);
 }
index 68f00c180bb9ffeb171b7850e289634ea2187bad..acb8ed7815582255f6d26885f27e6ff76458f288 100644 (file)
@@ -6,6 +6,9 @@
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_neox(i0, i1, pc);
 }
index 28a939ec6ad39cf8133cdd9d703f4f7f2153c175..0033cdb224f77d9e132ecb8fa85c45c96ad42fa0 100644 (file)
@@ -6,6 +6,9 @@
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_norm(i0, i1, pc);
 }
index 82f39cee349d8d5f3e8158e56443c5bfa82b53de..939cf3c51cdbfc2efaf04f752bc0772f5f08345a 100644 (file)
@@ -6,6 +6,7 @@
 struct rope_params {
     uint rope_mode;
     uint ncols;
+    uint nrows;
     uint n_dims;
     float freq_scale;
     uint p_delta_rows;
index ea1e0fdb416887a34ff64294b80445d8c47113f1..d93800b5e7666488dc3369b1f5e8cdb49b79c861 100644 (file)
@@ -6,6 +6,9 @@
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_vision(i0, i1, pc);
 }