]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: fuse rms_norm + mul + rope (+ view + set_rows) (#16977)
authorJeff Bolz <redacted>
Sat, 8 Nov 2025 07:52:15 +0000 (01:52 -0600)
committerGitHub <redacted>
Sat, 8 Nov 2025 07:52:15 +0000 (08:52 +0100)
This change combines the rms_norm+mul and rope+view+set_rows fusions to
allow fusing the whole sequence together. This comes up in Qwen3, Bailing,
and some other models.

12 files changed:
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl
ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
tests/test-backend-ops.cpp

index 2646e80be7582258eeda13a92d7297fe7deead0b..9c2aeb57f0003f082dccd2fcbcc292ac3c7105d3 100644 (file)
@@ -466,6 +466,14 @@ static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_ed
     { 2, 0, 1 }, // set_rows->src[0] == view
 };
 
+static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_view_set_rows_edges {
+    { 1, 0, 0 }, // mul->src[0]      == rms
+    { 2, 0, 1 }, // rope->src[0]     == mul
+    { 3, 0, 2 }, // view->src[0]     == rope
+    { 4, 0, 3 }, // set_rows->src[0] == view
+};
+
+
 struct vk_device_struct {
     std::recursive_mutex mutex;
 
@@ -617,6 +625,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_rms_norm_mul_f32;
     vk_pipeline pipeline_rms_norm_partials_f32;
     vk_pipeline pipeline_rms_norm_mul_partials_f32;
+    vk_pipeline pipeline_rms_norm_mul_rope_f32_f32;
+    vk_pipeline pipeline_rms_norm_mul_rope_f32_f16;
     vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;
 
@@ -1060,6 +1070,7 @@ struct vk_op_diag_mask_push_constants {
 };
 
 struct vk_op_rope_push_constants {
+    uint32_t rope_mode;
     uint32_t ncols;
     uint32_t n_dims;
     float freq_scale;
@@ -1079,6 +1090,12 @@ struct vk_op_rope_push_constants {
     uint32_t set_rows_stride;
 };
 
+// For fused rms_norm+mul+rope(+view+set_rows)
+struct vk_op_rms_norm_mul_rope_push_constants {
+    vk_op_binary_push_constants bin;
+    vk_op_rope_push_constants rope;
+};
+
 struct vk_op_soft_max_push_constants {
     uint32_t KX;
     uint32_t KY;
@@ -3557,6 +3574,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
 
+    if (device->float_controls_rte_fp16 &&
+        sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
+        ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+    }
+
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
@@ -9590,21 +9613,149 @@ static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const g
     return num_bytes;
 }
 
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params) {
+static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *dst, const ggml_tensor *src0, const bool has_ff, bool backprop, const uint32_t set_rows_stride) {
+    const int n_dims        = ((const int32_t *) dst->op_params)[1];
+    const int mode          = ((const int32_t *) dst->op_params)[2];
+    // const int n_ctx         = ((const int32_t *) dst->op_params)[3];
+    const int n_ctx_orig    = ((const int32_t *) dst->op_params)[4];
+    const float freq_base   = ((const float *)   dst->op_params)[5];
+    const float freq_scale  = ((const float *)   dst->op_params)[6];
+    const float ext_factor  = ((const float *)   dst->op_params)[7];
+    const float attn_factor = ((const float *)   dst->op_params)[8];
+    const float beta_fast   = ((const float *)   dst->op_params)[9];
+    const float beta_slow   = ((const float *)   dst->op_params)[10];
+    int sections[4] {};
+    if (mode & GGML_ROPE_TYPE_MROPE) {
+        memcpy(sections, (const int32_t *) dst->op_params + 11, sizeof(int)*4);
+    }
+
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
+
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type);
+    uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
+
+    vk_op_rope_push_constants rope {
+        (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
+        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
+        has_ff, (uint32_t)src0->ne[2], nb01, nb02,
+        { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
+    };
+
+    return rope;
+}
+
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx, float * op_params) {
+    ggml_tensor * dst;
+    const ggml_tensor * src0;
+    const ggml_tensor * src1;
+
+    if (ctx->num_additional_fused_ops > 0) {
+        // fused rms_norm + mul
+        ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+        ggml_tensor *other_src = mul->src[0] == cgraph->nodes[node_idx + 0] ? mul->src[1] : mul->src[0];
+        dst = mul;
+        src0 = cgraph->nodes[node_idx]->src[0];
+        src1 = other_src;
+    } else {
+        dst = cgraph->nodes[node_idx];
+        src0 = src1 = dst->src[0];
+    }
+
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
     uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
 
-    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
+    vk_op_binary_push_constants bin {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
         op_params[0], 0.0f, (int32_t)param3,
-    });
+    };
+
+    // more than one fused op means rms_norm+mul+rope
+    if (ctx->num_additional_fused_ops > 1) {
+        static constexpr uint32_t max_tensors = 7;
+        const ggml_tensor *tensors[max_tensors] {};
+
+        ggml_tensor *rms = cgraph->nodes[node_idx + 0];
+        ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+        ggml_tensor *rope = cgraph->nodes[node_idx + 2];
+
+        ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
+
+        bool do_set_rows = ctx->num_additional_fused_ops == 4;
+
+        tensors[0] = rms->src[0];
+        tensors[1] = other_src;
+        tensors[2] = mul;
+        tensors[3] = rope->src[1]; // pos
+        tensors[4] = rope->src[2]; // ff
+        tensors[5] = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; // dst
+        tensors[6] = do_set_rows ? tensors[5]->src[1] : nullptr;
+        const uint32_t set_rows_stride = do_set_rows ? tensors[5]->nb[1] / ggml_type_size(tensors[5]->type) : 0;
+
+        vk_op_rms_norm_mul_rope_push_constants pc;
+        pc.bin = bin;
+        pc.rope = ggml_vk_make_rope_constants(rope, rope->src[0], tensors[4] != nullptr, false, set_rows_stride);
+
+        vk_pipeline pipeline = tensors[5]->type == GGML_TYPE_F16 ? ctx->device->pipeline_rms_norm_mul_rope_f32_f16 : ctx->device->pipeline_rms_norm_mul_rope_f32_f32;
+
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+
+        ggml_backend_vk_buffer_context * buf_ctx[max_tensors];
+        vk_buffer buf[max_tensors];
+        size_t offset[max_tensors];
+        bool uma[max_tensors];
+
+        for (uint32_t i = 0; i < max_tensors; ++i) {
+            if (!tensors[i]) {
+                // If any remaining descriptors are unused, just point them at src[0]
+                buf[i] = buf[0];
+                offset[i] = 0;
+                continue;
+            }
+            buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
+            buf[i] = nullptr;
+            offset[i] = 0;
+            uma[i] = false;
+
+            if (ctx->device->uma) {
+                ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
+                uma[i] = buf[i] != nullptr;
+            }
+            if (!uma[i]) {
+                buf[i] = buf_ctx[i]->dev_buffer;
+                offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
+            }
+            GGML_ASSERT(buf[i] != nullptr);
+        }
+
+        std::array<uint32_t, 3> elements;
+        elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] };
+
+        static_assert(max_tensors == 7);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            {
+                ggml_vk_subbuffer(ctx, buf[0], offset[0]),
+                ggml_vk_subbuffer(ctx, buf[1], offset[1]),
+                ggml_vk_subbuffer(ctx, buf[2], offset[2]),
+                ggml_vk_subbuffer(ctx, buf[3], offset[3]),
+                ggml_vk_subbuffer(ctx, buf[4], offset[4]),
+                ggml_vk_subbuffer(ctx, buf[5], offset[5]),
+                ggml_vk_subbuffer(ctx, buf[6], offset[6]),
+            }, pc, elements);
+    } else {
+        ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, std::move(bin));
+    }
 
     if (ctx->do_add_rms_partials_offset_calculation) {
         ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
@@ -9758,9 +9909,6 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
     // const int n_ctx         = ((int32_t *) dst->op_params)[3];
     const int n_ctx_orig    = ((int32_t *) dst->op_params)[4];
     const float freq_base   = ((float *)   dst->op_params)[5];
-    const float freq_scale  = ((float *)   dst->op_params)[6];
-    const float ext_factor  = ((float *)   dst->op_params)[7];
-    const float attn_factor = ((float *)   dst->op_params)[8];
     const float beta_fast   = ((float *)   dst->op_params)[9];
     const float beta_slow   = ((float *)   dst->op_params)[10];
     int sections[4] {};
@@ -9768,16 +9916,9 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
     }
 
-    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
-
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type);
-    uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type);
-
     uint32_t set_rows_stride = 0;
     // Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride
     // and overrides the dst and sets src3=row_indices
@@ -9787,12 +9928,8 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         dst = cgraph->nodes[node_idx + 2];
     }
 
-    ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE, {
-        (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
-        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
-        src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
-        { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
-    });
+    ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE,
+        ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride));
 }
 
 static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -11307,6 +11444,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         if (n->op == GGML_OP_GLU) {
             std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
         }
+        if (n->op == GGML_OP_ROPE) {
+            const int mode = ((const int32_t *) n->op_params)[2];
+            std::cerr << " rope mode: " << mode;
+        }
         std::cerr << std::endl;
     }
 #endif
@@ -11414,14 +11555,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_RMS_NORM:
-        if (ctx->num_additional_fused_ops > 0) {
-            // fused rms_norm + mul
-            ggml_tensor *mul = cgraph->nodes[node_idx + 1];
-            ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params);
-        } else {
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params);
-        }
+        ggml_vk_rms_norm(ctx, compute_ctx, cgraph, node_idx, (float *)node->op_params);
         break;
     case GGML_OP_RMS_NORM_BACK:
         ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node);
@@ -12407,6 +12541,70 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
     return true;
 }
 
+// Check whether the tensors overlap in memory but are not equal.
+// Fusions can potenitally overwrite src tensors in ways that are not prevented
+// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them
+// to overlap if they are exactly equal.
+// XXX TODO this check is probably missing from several fusion optimizations.
+static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) {
+    ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context;
+    vk_buffer a_buf = a_buf_ctx->dev_buffer;
+    ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context;
+    vk_buffer b_buf = b_buf_ctx->dev_buffer;
+    if (a_buf == b_buf) {
+        auto a_base = vk_tensor_offset(a) + a->view_offs;
+        auto a_size = ggml_nbytes(a);
+        auto b_base = vk_tensor_offset(b) + b->view_offs;
+        auto b_size = ggml_nbytes(b);
+
+        if (a_base == b_base && a_size == b_size) {
+            return false;
+        }
+
+        if ((b_base <= a_base && a_base < b_base + b_size) ||
+            (a_base <= b_base && b_base < a_base + a_size)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
+                                               int node_idx) {
+    GGML_UNUSED(ctx);
+    const ggml_tensor *rms = cgraph->nodes[node_idx + 0];
+    const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+    const ggml_tensor *rope = cgraph->nodes[node_idx + 2];
+
+    const int mode = ((const int32_t *) rope->op_params)[2];
+
+    // noncontig tensors aren't tested, and don't seem common in practice
+    if (!ggml_is_contiguous(rms) ||
+        !ggml_is_contiguous(mul) ||
+        !ggml_is_contiguous(rope)) {
+        return false;
+    }
+
+    // only norm/neox are handled in the shader
+    if (mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_NORMAL) {
+        return false;
+    }
+
+    // shared memory size for passing data from mul->rope
+    if (mul->ne[0] > 1024) {
+        return false;
+    }
+
+    // must not overwrite srcs in a way that's not elementwise
+    ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
+    if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) ||
+        ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) {
+        return false;
+    }
+
+    return true;
+}
+
 static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
 
     const ggml_tensor *first_node = cgraph->nodes[node_idx];
@@ -12552,12 +12750,20 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
             if (num_adds) {
                 ctx->num_additional_fused_ops = num_adds - 1;
-            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-                ctx->num_additional_fused_ops = 1;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
                 ctx->num_additional_fused_ops = 1;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
                 ctx->num_additional_fused_ops = 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
+                       ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
+                       ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
+                       ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
+                ctx->num_additional_fused_ops = 4;
+            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&&
+                       ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
+                ctx->num_additional_fused_ops = 2;
+            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 1;
             } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
                        ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
                        ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
@@ -12790,14 +12996,34 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
             }
             if (ok) {
                 current_set.push_back(j);
+
+                int rope_idx = j;
+
+                // When we've found RMS_NORM + MUL, try to find a ROPE that uses it
+                if (j > 0 &&
+                    graph->nodes[j]->op == GGML_OP_MUL &&
+                    graph->nodes[j-1]->op == GGML_OP_RMS_NORM) {
+                    for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
+                        if (graph->nodes[k]->op == GGML_OP_ROPE &&
+                            graph->nodes[k]->src[0] == graph->nodes[j] &&
+                            // Check that other srcs are already valid
+                            graph->nodes[k]->src[1]->op == GGML_OP_NONE &&
+                            (graph->nodes[k]->src[2] == nullptr || graph->nodes[k]->src[2]->op == GGML_OP_NONE)) {
+                            rope_idx = k;
+                            current_set.push_back(rope_idx);
+                            used[rope_idx] = true;
+                            break;
+                        }
+                    }
+                }
                 // Look for ROPE + VIEW + SET_ROWS and make them consecutive
-                if (graph->nodes[j]->op == GGML_OP_ROPE) {
+                if (graph->nodes[rope_idx]->op == GGML_OP_ROPE) {
                     int view_idx = -1;
                     int set_rows_idx = -1;
-                    for (int k = j+1; k < std::min(j + 10, graph->n_nodes); ++k) {
+                    for (int k = rope_idx+1; k < std::min(rope_idx + 10, graph->n_nodes); ++k) {
                         if (view_idx == -1 &&
                             graph->nodes[k]->op == GGML_OP_VIEW &&
-                            graph->nodes[k]->src[0] == graph->nodes[j]) {
+                            graph->nodes[k]->src[0] == graph->nodes[rope_idx]) {
                             view_idx = k;
                             continue;
                         }
index 99595fc688c089dcecc46f9097c28c60c4a69d74..c1ad5172562d495bc9349e7117a2b0396b075bcf 100644 (file)
@@ -3,6 +3,9 @@
 
 #include "rte.glsl"
 #include "utils.glsl"
+#if RMS_NORM_ROPE_FUSION
+#include "rope_params.glsl"
+#endif
 
 layout (push_constant) uniform parameter
 {
@@ -12,11 +15,16 @@ layout (push_constant) uniform parameter
     uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
     uint misalign_offsets;
     float param1; float param2; int param3;
+#if RMS_NORM_ROPE_FUSION
+    rope_params rope;
+#endif
 } p;
 
+#if !RMS_NORM_ROPE_FUSION
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+#endif
 
 // true if src0/src1 are the same shape and the indices can be reused without additional modulus
 layout(constant_id = 0) const bool norepeat = false;
index d5b211ffaa7bb6142f1309264ac7ad31ddd00bf7..3a47949d5a657fe750c02ba690d1cd50b766efe4 100644 (file)
@@ -3,6 +3,32 @@
 #include "generic_binary_head.glsl"
 #include "types.glsl"
 
+#if RMS_NORM_ROPE_FUSION
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+
+// data is passed from rms_norm -> rope through shared memory.
+// rms_norm calls this data_d, rope calls this rope_data_a.
+// Binding 2 is not used
+shared FLOAT_TYPE rope_data_a[1024];
+#define data_d rope_data_a
+
+layout (binding = 3) readonly buffer R_Y {int rope_data_pos[];};
+layout (binding = 4) readonly buffer R_Z {float rope_data_ff[];};
+layout (binding = 5) writeonly buffer R_D {ROPE_D_TYPE rope_data_d[];};
+layout (binding = 6) readonly buffer R_I {uvec2 rope_data_i[];}; // indices for set_rows
+
+#include "rope_params.glsl"
+#include "rope_funcs.glsl"
+
+#define GGML_ROPE_TYPE_NORMAL 0
+#define GGML_ROPE_TYPE_NEOX   2
+#define GGML_ROPE_TYPE_MROPE  8
+#define GGML_ROPE_TYPE_VISION 24
+
+#endif
+
 #extension GL_EXT_control_flow_attributes : enable
 #define BLOCK_SIZE 512
 
@@ -28,8 +54,12 @@ void rms_norm(uint num_iters) {
 
     uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
     uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
+#if RMS_NORM_ROPE_FUSION
+    // Per-row offset in shared memory
+    uint32_t d_offset = 0;
+#else
     uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
-
+#endif
     FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
 
     [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
@@ -79,6 +109,18 @@ void rms_norm(uint num_iters) {
             data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
         }
     }
+#if RMS_NORM_ROPE_FUSION
+    barrier();
+    rope_params rp = p.rope;
+    uint rope_row = (samp*nchannels + channel)*nrows + row;
+    for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
+        if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
+            rope_neox(t, rope_row, rp);
+        } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
+            rope_norm(t, rope_row, rp);
+        }
+    }
+#endif
 }
 
 void main() {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
new file mode 100644 (file)
index 0000000..9726b72
--- /dev/null
@@ -0,0 +1,227 @@
+
+float rope_yarn_ramp(const float low, const float high, const uint i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) {
+#if RMS_NORM_ROPE_FUSION
+    // Per-row offset in shared memory
+    const uint ix = i0;
+#else
+    const uint ix = i02*p.nb02 + i01*p.nb01 + i0;
+#endif
+    return ix;
+}
+
+void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta, rope_params p) {
+    float mscale = p.attn_factor;
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = p.freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (p.ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
+    }
+    // Backprogagation uses inverted rotation
+    if (p.is_back != 0) {
+        theta = -theta;
+    }
+    cos_theta = cos(theta) * mscale;
+    sin_theta = sin(theta) * mscale;
+}
+
+void rope_norm(const uint i0, const uint i1, rope_params p) {
+    uint ne0 = p.ncols;
+    uint ne1 = p.p_delta_rows;
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    // i1 is actually i2*nb2+i1, but the rows are contiguous
+    const uint i01 = i1 % ne1;
+    const uint i02 = i1 / ne1;
+
+    uint idst = i1*ne0 + i0;
+    const uint ix = rope_a_coord(i0, i01, i02, p);
+
+    // Fusion optimization: ROPE + VIEW + SET_ROWS..
+    // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
+    if (p.set_rows_stride != 0) {
+        idst = i01*ne0 + i0;
+        idst += rope_data_i[i02].x * p.set_rows_stride;
+    }
+
+    if (i0 >= p.n_dims) {
+        rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]);
+        rope_data_d[idst + 1] = ROPE_D_TYPE(rope_data_a[ix + 1]);
+
+        return;
+    }
+
+    const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
+
+    const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
+
+    float cos_theta, sin_theta;
+    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
+
+    const float x0 = float(rope_data_a[ix + 0]);
+    const float x1 = float(rope_data_a[ix + 1]);
+
+    rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
+    rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
+}
+
+void rope_neox(const uint i0, const uint i1, rope_params p) {
+    uint ne0 = p.ncols;
+    uint ne1 = p.p_delta_rows;
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const uint i01 = i1 % ne1;
+    const uint i02 = i1 / ne1;
+
+    uint idst = i1*ne0 + i0/2;
+    const uint ix = rope_a_coord(i0/2, i01, i02, p);
+
+    // Fusion optimization: ROPE + VIEW + SET_ROWS..
+    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
+    if (p.set_rows_stride != 0) {
+        idst = i01*ne0 + i0/2;
+        idst += rope_data_i[i02].x * p.set_rows_stride;
+    }
+
+    if (i0 >= p.n_dims) {
+        rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
+        rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]);
+
+        return;
+    }
+
+    const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
+
+    const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
+
+    float cos_theta, sin_theta;
+    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
+
+    const float x0 = float(rope_data_a[ix + 0]);
+    const float x1 = float(rope_data_a[ix + p.n_dims/2]);
+
+    rope_data_d[idst + 0]          = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
+    rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
+}
+
+
+void rope_multi(const uint i0, const uint i1, rope_params p) {
+    uint ne0 = p.ncols;
+    uint ne1 = p.p_delta_rows;
+    uint ne2 = p.ne02;
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const uint i01 = i1 % ne1;
+    const uint i02 = i1 / ne1;
+
+    const uint idst = i1*ne0 + i0/2;
+    const uint ix = rope_a_coord(i0/2, i01, i02, p);
+
+    if (i0 >= p.n_dims) {
+        rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
+        rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]);
+
+        return;
+    }
+
+    const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
+    const int sec_w = p.sections[1] + p.sections[0];
+    const uint sector = (i0 / 2) % sect_dims;
+
+    float theta_base = 0.0;
+    if (p.is_imrope != 0) {
+        if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
+            theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+        } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
+            theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+        } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
+            theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
+        } else {
+            theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+        }
+    } else {
+        if (sector < p.sections[0]) {
+            theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
+        }
+        else if (sector >= p.sections[0] && sector < sec_w) {
+            theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+        }
+        else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
+            theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+        }
+        else if (sector >= sec_w + p.sections[2]) {
+            theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+        }
+    }
+
+    const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
+
+    float cos_theta, sin_theta;
+    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
+
+    const float x0 = float(rope_data_a[ix + 0]);
+    const float x1 = float(rope_data_a[ix + p.n_dims/2]);
+
+    rope_data_d[idst + 0]          = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
+    rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
+}
+
+void rope_vision(const uint i0, const uint i1, rope_params p) {
+    uint ne0 = p.ncols;
+    uint ne1 = p.p_delta_rows;
+    uint ne2 = p.ne02;
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const uint i01 = i1 % ne1;
+    const uint i02 = i1 / ne1;
+
+    const uint idst = i1*ne0 + i0/2;
+    const uint ix = rope_a_coord(i0/2, i01, i02, p);
+
+    const int sect_dims = p.sections[0] + p.sections[1];
+    const int sec_w = p.sections[1] + p.sections[0];
+    const uint sector = (i0 / 2) % sect_dims;
+
+    float theta_base = 0.0;
+    if (sector < p.sections[0]) {
+        const uint p0 = sector;
+        theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0);
+    }
+    else if (sector >= p.sections[0] && sector < sec_w) {
+        const uint p0 = sector - p.sections[0];
+        theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0);
+    }
+
+    const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
+
+    float cos_theta, sin_theta;
+    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
+
+    const float x0 = float(rope_data_a[ix + 0]);
+    const float x1 = float(rope_data_a[ix + p.n_dims]);
+
+    rope_data_d[idst + 0]        = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
+    rope_data_d[idst + p.n_dims] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
+}
+
index fa2bb33394cb2df081c31d1723cce3a23e212740..d9b4d4c03f34fc04d0243da8bfadf05637c59ac9 100644 (file)
@@ -3,56 +3,18 @@
 #extension GL_EXT_shader_16bit_storage : require
 
 #include "rte.glsl"
+#include "rope_params.glsl"
 
 layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer Y {int data_pos[];};
-layout (binding = 2) readonly buffer Z {float data_ff[];};
-layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
-layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows
+layout (binding = 0) readonly buffer X {A_TYPE rope_data_a[];};
+layout (binding = 1) readonly buffer Y {int rope_data_pos[];};
+layout (binding = 2) readonly buffer Z {float rope_data_ff[];};
+layout (binding = 3) writeonly buffer D {ROPE_D_TYPE rope_data_d[];};
+layout (binding = 4) readonly buffer I {uvec2 rope_data_i[];}; // indices for set_rows
 
-layout (push_constant) uniform parameter {
-    uint ncols;
-    uint n_dims;
-    float freq_scale;
-    uint p_delta_rows;
-    float freq_base;
-    float ext_factor;
-    float attn_factor;
-    float corr_dims[2];
-    float theta_scale;
-    uint has_ff;
-    uint ne02;
-    uint s1;
-    uint s2;
-    int sections[4];
-    uint is_imrope;
-    uint is_back;
-    uint set_rows_stride;
-} p;
-
-float rope_yarn_ramp(const float low, const float high, const uint i0) {
-    const float y = (i0 / 2 - low) / max(0.001f, high - low);
-    return 1.0f - min(1.0f, max(0.0f, y));
-}
 
-void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) {
-    float mscale = p.attn_factor;
-    // Get n-d rotational scaling corrected for extrapolation
-    float theta_interp = p.freq_scale * theta_extrap;
-    float theta = theta_interp;
-    if (p.ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
-        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+layout (push_constant) uniform parameter {
+    rope_params pc;
+};
 
-        // Get n-d magnitude scaling corrected for interpolation
-        mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
-    }
-    // Backprogagation uses inverted rotation
-    if (p.is_back != 0) {
-        theta = -theta;
-    }
-    cos_theta = cos(theta) * mscale;
-    sin_theta = sin(theta) * mscale;
-}
index 54aabcf22283893470cd75c56fda2245f57a7d24..7c1fb1cd22440ece5cb42e02866a65f79fa4415f 100644 (file)
@@ -1,70 +1,11 @@
 #version 450
 
 #include "rope_head.glsl"
+#include "rope_funcs.glsl"
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-    uint ne2 = p.ne02;
-
-    if (i0 >= ne0) {
-        return;
-    }
-
-    const uint row_dst = gl_GlobalInvocationID.x;
-
-    const uint row_x     = row_dst % ne1;
-    const uint channel_x = row_dst / ne1;
-
-    const uint idst = row_dst*ne0 + i0/2;
-    const uint ix   = channel_x*p.s2 + row_x*p.s1 + i0/2;
-
-    if (i0 >= p.n_dims) {
-        data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
-        data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
-
-        return;
-    }
-
-    const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
-    const int sec_w = p.sections[1] + p.sections[0];
-    const uint sector = (i0 / 2) % sect_dims;
-
-    float theta_base = 0.0;
-    if (p.is_imrope != 0) {
-        if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
-            theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
-        } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
-            theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
-        } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
-            theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
-        } else {
-            theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
-        }
-    } else {
-        if (sector < p.sections[0]) {
-            theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
-        }
-        else if (sector >= p.sections[0] && sector < sec_w) {
-            theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
-        }
-        else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
-            theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
-        }
-        else if (sector >= sec_w + p.sections[2]) {
-            theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
-        }
-    }
-
-    const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
-
-    const float x0 = float(data_a[ix + 0]);
-    const float x1 = float(data_a[ix + p.n_dims/2]);
-
-    data_d[idst + 0]          = D_TYPE(x0*cos_theta - x1*sin_theta);
-    data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
+    // i1 is actually i2*nb2+i1, but the rows are contiguous
+    const uint i1 = gl_GlobalInvocationID.x;
+    rope_multi(i0, i1, pc);
 }
index 9f4538155a05ccb73dcc8c5414461d5301ce1e16..68f00c180bb9ffeb171b7850e289634ea2187bad 100644 (file)
@@ -1,48 +1,11 @@
 #version 450
 
 #include "rope_head.glsl"
+#include "rope_funcs.glsl"
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-
-    if (i0 >= ne0) {
-        return;
-    }
-
-    const uint row_dst = gl_GlobalInvocationID.x;
-
-    const uint row_x     = row_dst % ne1;
-    const uint channel_x = row_dst / ne1;
-
-    uint idst = row_dst*ne0 + i0/2;
-    const uint ix   = channel_x*p.s2 + row_x*p.s1 + i0/2;
-
-    // Fusion optimization: ROPE + VIEW + SET_ROWS..
-    // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
-    if (p.set_rows_stride != 0) {
-        idst = row_x*ne0 + i0/2;
-        idst += data_i[channel_x].x * p.set_rows_stride;
-    }
-
-    if (i0 >= p.n_dims) {
-        data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]);
-        data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]);
-
-        return;
-    }
-
-    const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
-
-    const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
-
-    const float x0 = float(data_a[ix + 0]);
-    const float x1 = float(data_a[ix + p.n_dims/2]);
-
-    data_d[idst + 0]          = D_TYPE(x0*cos_theta - x1*sin_theta);
-    data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
+    // i1 is actually i2*nb2+i1, but the rows are contiguous
+    const uint i1 = gl_GlobalInvocationID.x;
+    rope_neox(i0, i1, pc);
 }
index f4209ed9582aa5081d31648929f1e9028301e267..28a939ec6ad39cf8133cdd9d703f4f7f2153c175 100644 (file)
@@ -1,48 +1,11 @@
 #version 450
 
 #include "rope_head.glsl"
+#include "rope_funcs.glsl"
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-
-    if (i0 >= ne0) {
-        return;
-    }
-
-    const uint row_dst = gl_GlobalInvocationID.x;
-
-    const uint row_x     = row_dst % ne1;
-    const uint channel_x = row_dst / ne1;
-
-    uint idst = row_dst*ne0 + i0;
-    const uint ix   = channel_x*p.s2 + row_x*p.s1 + i0;
-
-    // Fusion optimization: ROPE + VIEW + SET_ROWS..
-    // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
-    if (p.set_rows_stride != 0) {
-        idst = row_x*ne0 + i0;
-        idst += data_i[channel_x].x * p.set_rows_stride;
-    }
-
-    if (i0 >= p.n_dims) {
-        data_d[idst + 0] = D_TYPE(data_a[ix + 0]);
-        data_d[idst + 1] = D_TYPE(data_a[ix + 1]);
-
-        return;
-    }
-
-    const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
-
-    const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
-
-    const float x0 = float(data_a[ix + 0]);
-    const float x1 = float(data_a[ix + 1]);
-
-    data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
-    data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
+    // i1 is actually i2*nb2+i1, but the rows are contiguous
+    const uint i1 = gl_GlobalInvocationID.x;
+    rope_norm(i0, i1, pc);
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
new file mode 100644 (file)
index 0000000..82f39ce
--- /dev/null
@@ -0,0 +1,27 @@
+#if !defined(GGML_ROPE_PARAMS)
+#define GGML_ROPE_PARAMS
+
+#include "rte.glsl"
+
+struct rope_params {
+    uint rope_mode;
+    uint ncols;
+    uint n_dims;
+    float freq_scale;
+    uint p_delta_rows;
+    float freq_base;
+    float ext_factor;
+    float attn_factor;
+    float corr_dims[2];
+    float theta_scale;
+    uint has_ff;
+    uint ne02;
+    uint nb01;
+    uint nb02;
+    int sections[4];
+    uint is_imrope;
+    uint is_back;
+    uint set_rows_stride;
+};
+
+#endif // !defined(GGML_ROPE_PARAMS)
index d37d1c1043f8af006bfb5184c18381359763cec5..ea1e0fdb416887a34ff64294b80445d8c47113f1 100644 (file)
@@ -1,47 +1,11 @@
 #version 450
 
 #include "rope_head.glsl"
+#include "rope_funcs.glsl"
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-    uint ne2 = p.ne02;
-
-    if (i0 >= ne0) {
-        return;
-    }
-
-    const uint row_dst = gl_GlobalInvocationID.x;
-
-    const uint row_x     = row_dst % ne1;
-    const uint channel_x = row_dst / ne1;
-
-    const uint idst = row_dst*ne0 + i0/2;
-    const uint ix   = channel_x*p.s2 + row_x*p.s1 + i0/2;
-
-    const int sect_dims = p.sections[0] + p.sections[1];
-    const int sec_w = p.sections[1] + p.sections[0];
-    const uint sector = (i0 / 2) % sect_dims;
-
-    float theta_base = 0.0;
-    if (sector < p.sections[0]) {
-        const uint p0 = sector;
-        theta_base = data_pos[channel_x]*pow(p.theta_scale, p0);
-    }
-    else if (sector >= p.sections[0] && sector < sec_w) {
-        const uint p0 = sector - p.sections[0];
-        theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0);
-    }
-
-    const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
-
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
-
-    const float x0 = float(data_a[ix + 0]);
-    const float x1 = float(data_a[ix + p.n_dims]);
-
-    data_d[idst + 0]        = D_TYPE(x0*cos_theta - x1*sin_theta);
-    data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta);
+    // i1 is actually i2*nb2+i1, but the rows are contiguous
+    const uint i1 = gl_GlobalInvocationID.x;
+    rope_vision(i0, i1, pc);
 }
index bd178875d55f6775a63478295dea10684d1a7c82..c2e42cf006e96daf8e823acbbb394b71b105785c 100644 (file)
@@ -695,6 +695,8 @@ void process_shaders() {
     string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}}));
+    string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}}));
     string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
@@ -840,25 +842,25 @@ void process_shaders() {
     string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
     string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 
-    string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
-    string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
-    string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
-    string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
-
-    string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
-    string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
-    string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
-    string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
-
-    string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
-    string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
-
-    string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
-    string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+    string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
+    string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
+    string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
+    string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
+    string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
+
+    string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
+    string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
+    string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
+    string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
+    string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
+
+    string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
+    string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
+    string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
+
+    string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
+    string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
+    string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
 
     string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
 
index b9ae82eeddd71e1443025a5e59045e90bf5e889d..80216f6f3a2a66a0cc81ac5b7da671236c604291 100644 (file)
@@ -2294,6 +2294,79 @@ struct test_rope_set_rows : public test_case {
     }
 };
 
+// GGML_OP_RMS_NORM + GGML_OP_MUL + GGML_OP_ROPE (+ GGML_OP_VIEW + GGML_OP_SET_ROWS)
+struct test_rms_norm_mul_rope : public test_case {
+    const std::array<int64_t, 4> ne;
+    const float eps;
+    const bool multi_add; // test a sequence of adds feeding into rms_norm
+    const bool set_rows;
+    int mode;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "RMS_NORM_MUL_ROPE";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    std::string vars() override {
+        return VARS_TO_STR5(ne, eps, multi_add, set_rows, mode);
+    }
+
+    test_rms_norm_mul_rope(std::array<int64_t, 4> ne, float eps = 1e-6f, bool multi_add = false,
+                           bool set_rows = false, int mode = GGML_ROPE_TYPE_NORMAL)
+        : ne(ne), eps(eps), multi_add(multi_add), set_rows(set_rows), mode(mode) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
+        ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
+        ggml_tensor * c = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
+
+        if (multi_add) {
+            a = ggml_add(ctx, ggml_add(ctx, a, b), c);
+        }
+
+        a = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);
+
+        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
+
+        ggml_tensor * rope = ggml_rope(ctx, a, pos, ne[0], mode);
+
+        ggml_tensor * out;
+
+        if (set_rows) {
+            ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0);
+
+            ggml_tensor * dst = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, ne[0] * ne[1], ne[2] * ne[3], 1, 1);
+            ggml_set_name(dst, "dst");
+
+            ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, ne[2], 1, 1);
+            ggml_set_name(row_idxs, "row_idxs");
+
+            out = ggml_set_rows(ctx, dst, view, row_idxs);
+            ggml_set_name(out, "out");
+        } else {
+            out = rope;
+        }
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) {
+                    continue;
+                }
+
+                init_set_rows_row_ids(t, ne[2]);
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
 // GGML_OP_ARGMAX
 struct test_argmax : public test_case {
     const ggml_type type;
@@ -6751,6 +6824,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    for (auto multi_add : {false, true}) {
+        for (auto set_rows : {false, true}) {
+            for (auto rope : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX}) {
+                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 1, 1, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 1, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 5, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 2, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 2, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 50, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 50, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope));
+            }
+        }
+    }
+
     test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
 
     for (int64_t d_conv : {3, 4}) {