]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: handle rope with large number of rows (llama/18306)
authorJeff Bolz <redacted>
Fri, 26 Dec 2025 15:53:46 +0000 (09:53 -0600)
committerGeorgi Gerganov <redacted>
Wed, 31 Dec 2025 10:39:43 +0000 (12:39 +0200)
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/rope_multi.comp
src/ggml-vulkan/vulkan-shaders/rope_neox.comp
src/ggml-vulkan/vulkan-shaders/rope_norm.comp
src/ggml-vulkan/vulkan-shaders/rope_params.glsl
src/ggml-vulkan/vulkan-shaders/rope_vision.comp
tests/test-backend-ops.cpp

index 1459b2608e5c1148c33c400c19cd66fd95480cd4..e7ce518fbac68eca67ca24881ccaab09b5730a81 100644 (file)
@@ -1192,6 +1192,7 @@ struct vk_op_diag_mask_push_constants {
 struct vk_op_rope_push_constants {
     uint32_t rope_mode;
     uint32_t ncols;
+    uint32_t nrows;
     uint32_t n_dims;
     float freq_scale;
     uint32_t p_delta_rows;
@@ -9090,10 +9091,20 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
         } break;
     case GGML_OP_DIAG_MASK_INF:
-    case GGML_OP_ROPE:
-    case GGML_OP_ROPE_BACK:
         elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
         break;
+    case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
+        {
+            uint32_t nrows = (uint32_t)ggml_nrows(src0);
+            uint32_t z = 1;
+            if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) {
+                z = CEIL_DIV(nrows, 32768);
+                nrows = 32768;
+            }
+            elements = { nrows, (uint32_t)ne00, z };
+
+        } break;
     case GGML_OP_GET_ROWS:
         elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
         elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
@@ -10021,7 +10032,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
     uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
 
     vk_op_rope_push_constants rope {
-        (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
+        (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
         freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
         has_ff, (uint32_t)src0->ne[2], nb01, nb02,
         { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
index 7c1fb1cd22440ece5cb42e02866a65f79fa4415f..f7587468a81534ef1c20abba2b49536705500db0 100644 (file)
@@ -6,6 +6,9 @@
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_multi(i0, i1, pc);
 }
index 68f00c180bb9ffeb171b7850e289634ea2187bad..acb8ed7815582255f6d26885f27e6ff76458f288 100644 (file)
@@ -6,6 +6,9 @@
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_neox(i0, i1, pc);
 }
index 28a939ec6ad39cf8133cdd9d703f4f7f2153c175..0033cdb224f77d9e132ecb8fa85c45c96ad42fa0 100644 (file)
@@ -6,6 +6,9 @@
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_norm(i0, i1, pc);
 }
index 82f39cee349d8d5f3e8158e56443c5bfa82b53de..939cf3c51cdbfc2efaf04f752bc0772f5f08345a 100644 (file)
@@ -6,6 +6,7 @@
 struct rope_params {
     uint rope_mode;
     uint ncols;
+    uint nrows;
     uint n_dims;
     float freq_scale;
     uint p_delta_rows;
index ea1e0fdb416887a34ff64294b80445d8c47113f1..d93800b5e7666488dc3369b1f5e8cdb49b79c861 100644 (file)
@@ -6,6 +6,9 @@
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_vision(i0, i1, pc);
 }
index 6b65f6e1c75a2c74b65b7e969a5e6f154f313a34..a801455f16fc21036056275c400e9bef99cbb022 100644 (file)
@@ -7775,6 +7775,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                                     test_cases.emplace_back(new test_rope(type, {128,  40, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 13B
                                     test_cases.emplace_back(new test_rope(type, {128,  52, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 30B
                                     test_cases.emplace_back(new test_rope(type, {128,  64, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 65B
+                                    test_cases.emplace_back(new test_rope(type, {16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
                                 }
 
                                 if (all) {
@@ -7789,6 +7790,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                                     test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
                                     test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
                                     test_cases.emplace_back(new test_rope(type, { 80,  32, 4, 1},  32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+                                    test_cases.emplace_back(new test_rope(type, { 16, 16, 8192, 1},  16, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw));
                                 }
 
                                 if (all) {
@@ -7802,6 +7804,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                                     test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1},  32, GGML_ROPE_TYPE_IMROPE,  512, fs, ef, af, ff, v, fw));
                                     test_cases.emplace_back(new test_rope(type, { 80,  16, 2, 1},  80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
                                     test_cases.emplace_back(new test_rope(type, {128,  16, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen3vl)
+                                    test_cases.emplace_back(new test_rope(type, {16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
                                 }
 
                                 test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1},  64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)