]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: fix non-contig rope (llama/19299)
authorJeff Bolz <redacted>
Thu, 5 Feb 2026 07:38:59 +0000 (01:38 -0600)
committerGeorgi Gerganov <redacted>
Sat, 7 Feb 2026 08:37:38 +0000 (10:37 +0200)
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/rms_norm.comp
src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
src/ggml-vulkan/vulkan-shaders/rope_multi.comp
src/ggml-vulkan/vulkan-shaders/rope_neox.comp
src/ggml-vulkan/vulkan-shaders/rope_norm.comp
src/ggml-vulkan/vulkan-shaders/rope_params.glsl
src/ggml-vulkan/vulkan-shaders/rope_vision.comp

index cb7fa2c9cbb2b3a6119a7d210ac016cbbc051c6f..af57685a37dda37d1d7c10fa4e496c4788ea4c66 100644 (file)
@@ -1263,25 +1263,30 @@ struct vk_op_diag_mask_push_constants {
 
 struct vk_op_rope_push_constants {
     uint32_t rope_mode;
-    uint32_t ncols;
     uint32_t nrows;
     uint32_t n_dims;
     float freq_scale;
-    uint32_t p_delta_rows;
     float freq_base;
     float ext_factor;
     float attn_factor;
     float corr_dims[2];
     float theta_scale;
     uint32_t has_ff;
-    uint32_t ne02;
-    uint32_t s1;
-    uint32_t s2;
     int32_t sections[4];
     uint32_t is_imrope;
     uint32_t is_back;
     uint32_t set_rows_stride;
+    uint32_t ne00;
+    uint32_t ne01;
+    uint32_t ne02;
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb03;
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
 };
+static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");
 
 // For fused rms_norm+mul+rope(+view+set_rows)
 struct vk_op_rms_norm_mul_rope_push_constants {
@@ -10405,12 +10410,22 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
 
     uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type);
     uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
+    uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type);
+
+    uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type);
+    uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type);
+    uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type);
 
     vk_op_rope_push_constants rope {
-        (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
-        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
-        has_ff, (uint32_t)src0->ne[2], nb01, nb02,
+        (uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale,
+        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff,
         { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
+
+        (uint32_t)src0->ne[0],
+        (uint32_t)src0->ne[1],
+        (uint32_t)src0->ne[2],
+        nb01, nb02, nb03,
+        nb11, nb12, nb13,
     };
 
     return rope;
@@ -14798,6 +14813,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_REPEAT_BACK:
             return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ROPE:
+            return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_ROPE_BACK:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
index 9d6d366542741ce43da1192c6f2de18254db6470..55b89f19a7a894c65dc7fd41eaa0b592b2a6e71a 100644 (file)
@@ -112,12 +112,11 @@ void rms_norm(uint num_iters) {
 #if RMS_NORM_ROPE_FUSION
     barrier();
     rope_params rp = p.rope;
-    uint rope_row = (samp*nchannels + channel)*nrows + row;
     for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
         if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
-            rope_neox(t, rope_row, rp);
+            rope_neox(t, row, channel, samp, rp);
         } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
-            rope_norm(t, rope_row, rp);
+            rope_norm(t, row, channel, samp, rp);
         }
     }
 #endif
index aacec984696f01338591d02578161fe1bc9dfded..2e53459909d7f33a000e4a5174bb762a79085cff 100644 (file)
@@ -4,12 +4,12 @@ float rope_yarn_ramp(const float low, const float high, const uint i0) {
     return 1.0f - min(1.0f, max(0.0f, y));
 }
 
-uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) {
+uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, rope_params p) {
 #if RMS_NORM_ROPE_FUSION
     // Per-row offset in shared memory
     const uint ix = i0;
 #else
-    const uint ix = i02*p.nb02 + i01*p.nb01 + i0;
+    const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0;
 #endif
     return ix;
 }
@@ -34,26 +34,19 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out
     sin_theta = sin(theta) * mscale;
 }
 
-void rope_norm(const uint i0, const uint i1, rope_params p) {
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-
-    if (i0 >= ne0) {
+void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+    if (i0 >= p.ne00) {
         return;
     }
 
-    // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i01 = i1 % ne1;
-    const uint i02 = i1 / ne1;
-
-    uint idst = i1*ne0 + i0;
-    const uint ix = rope_a_coord(i0, i01, i02, p);
+    uint idst = i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+    const uint ix = rope_a_coord(i0, i1, i2, i3, p);
 
     // Fusion optimization: ROPE + VIEW + SET_ROWS.
     // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
     if (p.set_rows_stride != 0) {
-        idst = i01*ne0 + i0;
-        idst += rope_data_i[i02].x * p.set_rows_stride;
+        idst = i1*p.nb11 + i0;
+        idst += rope_data_i[i2].x * p.set_rows_stride;
     }
 
     if (i0 >= p.n_dims) {
@@ -63,7 +56,7 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
         return;
     }
 
-    const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
+    const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);
 
     const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
 
@@ -77,25 +70,19 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
     rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
 }
 
-void rope_neox(const uint i0, const uint i1, rope_params p) {
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-
-    if (i0 >= ne0) {
+void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+    if (i0 >= p.ne00) {
         return;
     }
 
-    const uint i01 = i1 % ne1;
-    const uint i02 = i1 / ne1;
-
-    uint idst = i1*ne0 + i0/2;
-    const uint ix = rope_a_coord(i0/2, i01, i02, p);
+    uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
 
     // Fusion optimization: ROPE + VIEW + SET_ROWS.
     // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
     if (p.set_rows_stride != 0) {
-        idst = i01*ne0 + i0/2;
-        idst += rope_data_i[i02].x * p.set_rows_stride;
+        idst = i1*p.nb11 + i0/2;
+        idst += rope_data_i[i2].x * p.set_rows_stride;
     }
 
     if (i0 >= p.n_dims) {
@@ -105,7 +92,7 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
         return;
     }
 
-    const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
+    const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);
 
     const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
 
@@ -120,26 +107,19 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
 }
 
 
-void rope_multi(const uint i0, const uint i1, rope_params p) {
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-    uint ne2 = p.ne02;
-
-    if (i0 >= ne0) {
+void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+    if (i0 >= p.ne00) {
         return;
     }
 
-    const uint i01 = i1 % ne1;
-    const uint i02 = i1 / ne1;
-
-    uint idst = i1*ne0 + i0/2;
-    const uint ix = rope_a_coord(i0/2, i01, i02, p);
+    uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
 
     // Fusion optimization: ROPE + VIEW + SET_ROWS.
     // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
     if (p.set_rows_stride != 0) {
-        idst = i01*ne0 + i0/2;
-        idst += rope_data_i[i02].x * p.set_rows_stride;
+        idst = i1*p.nb11 + i0/2;
+        idst += rope_data_i[i2].x * p.set_rows_stride;
     }
 
     if (i0 >= p.n_dims) {
@@ -156,26 +136,26 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
     float theta_base = 0.0;
     if (p.is_imrope != 0) {
         if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
-            theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
         } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
-            theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
         } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
-            theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
         } else {
-            theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
         }
     } else {
         if (sector < p.sections[0]) {
-            theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
         }
         else if (sector >= p.sections[0] && sector < sec_w) {
-            theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
         }
         else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
-            theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
         }
         else if (sector >= sec_w + p.sections[2]) {
-            theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
         }
     }
 
@@ -191,20 +171,13 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
     rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
 }
 
-void rope_vision(const uint i0, const uint i1, rope_params p) {
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-    uint ne2 = p.ne02;
-
-    if (i0 >= ne0) {
+void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+    if (i0 >= p.ne00) {
         return;
     }
 
-    const uint i01 = i1 % ne1;
-    const uint i02 = i1 / ne1;
-
-    const uint idst = i1*ne0 + i0/2;
-    const uint ix = rope_a_coord(i0/2, i01, i02, p);
+    const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
 
     const int sect_dims = p.sections[0] + p.sections[1];
     const int sec_w = p.sections[1] + p.sections[0];
@@ -213,11 +186,11 @@ void rope_vision(const uint i0, const uint i1, rope_params p) {
     float theta_base = 0.0;
     if (sector < p.sections[0]) {
         const uint p0 = sector;
-        theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0);
+        theta_base = rope_data_pos[i2]*pow(p.theta_scale, p0);
     }
     else if (sector >= p.sections[0] && sector < sec_w) {
         const uint p0 = sector - p.sections[0];
-        theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0);
+        theta_base = rope_data_pos[i2 + p.ne02]*pow(p.theta_scale, p0);
     }
 
     const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
index f7587468a81534ef1c20abba2b49536705500db0..1528fbeeaecfb04c909ebeec1f5a5512bff344cf 100644 (file)
@@ -5,10 +5,13 @@
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
-    if (i1 >= pc.nrows) {
+    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (row >= pc.nrows) {
         return;
     }
-    rope_multi(i0, i1, pc);
+    const uint i3 = row / (pc.ne01*pc.ne02);
+    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
+    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
+
+    rope_multi(i0, i1, i2, i3, pc);
 }
index acb8ed7815582255f6d26885f27e6ff76458f288..ad0896095db62c734ec79c756f22305f36245300 100644 (file)
@@ -5,10 +5,13 @@
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
-    if (i1 >= pc.nrows) {
+    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (row >= pc.nrows) {
         return;
     }
-    rope_neox(i0, i1, pc);
+    const uint i3 = row / (pc.ne01*pc.ne02);
+    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
+    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
+
+    rope_neox(i0, i1, i2, i3, pc);
 }
index 0033cdb224f77d9e132ecb8fa85c45c96ad42fa0..11220817df09311a3fb40027cc44874f75b8f8d9 100644 (file)
@@ -5,10 +5,13 @@
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
-    if (i1 >= pc.nrows) {
+    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (row >= pc.nrows) {
         return;
     }
-    rope_norm(i0, i1, pc);
+    const uint i3 = row / (pc.ne01*pc.ne02);
+    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
+    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
+
+    rope_norm(i0, i1, i2, i3, pc);
 }
index 939cf3c51cdbfc2efaf04f752bc0772f5f08345a..ec6ceaca9bda4a39017fa2feaa1e0c3b3cb94ece 100644 (file)
@@ -5,24 +5,29 @@
 
 struct rope_params {
     uint rope_mode;
-    uint ncols;
     uint nrows;
     uint n_dims;
     float freq_scale;
-    uint p_delta_rows;
     float freq_base;
     float ext_factor;
     float attn_factor;
     float corr_dims[2];
     float theta_scale;
     uint has_ff;
-    uint ne02;
-    uint nb01;
-    uint nb02;
     int sections[4];
     uint is_imrope;
     uint is_back;
     uint set_rows_stride;
+
+    uint ne00;
+    uint ne01;
+    uint ne02;
+    uint nb01;
+    uint nb02;
+    uint nb03;
+    uint nb11;
+    uint nb12;
+    uint nb13;
 };
 
 #endif // !defined(GGML_ROPE_PARAMS)
index d93800b5e7666488dc3369b1f5e8cdb49b79c861..ca71efb2f5586e715e1f9b5040b459fa3594e6d6 100644 (file)
@@ -5,10 +5,13 @@
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
-    if (i1 >= pc.nrows) {
+    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (row >= pc.nrows) {
         return;
     }
-    rope_vision(i0, i1, pc);
+    const uint i3 = row / (pc.ne01*pc.ne02);
+    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
+    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
+
+    rope_vision(i0, i1, i2, i3, pc);
 }