]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: Extend rope fusions to allow mrope (llama/18264)
authorJeff Bolz <redacted>
Mon, 22 Dec 2025 17:03:13 +0000 (11:03 -0600)
committerGeorgi Gerganov <redacted>
Wed, 31 Dec 2025 15:52:09 +0000 (17:52 +0200)
Extend the test-backend-ops tests as well.

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index c2adca9cba866843eaca1310e27781265e68bc43..a524adbe0cfa4f37223e064180a74c0bab4b6412 100644 (file)
@@ -731,7 +731,7 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
-    vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
+    vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16, pipeline_rope_multi_f32_f16;
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
     vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
@@ -4077,6 +4077,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
         ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     } else {
         ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -4085,6 +4086,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
         ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     }
 
     for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
@@ -8680,6 +8682,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
                     return ctx->device->pipeline_rope_multi_f32;
                 }
+                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+                    return ctx->device->pipeline_rope_multi_f32_f16;
+                }
                 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
                     return ctx->device->pipeline_rope_multi_f16;
                 }
@@ -13076,9 +13081,9 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
         return false;
     }
 
-    // Only norm/neox shaders have the fusion code
+    // Only norm/neox/mrope shaders have the fusion code
     const int mode = ((const int32_t *) rope->op_params)[2];
-    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
+    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_MROPE) {
         return false;
     }
 
index 9726b722d1e461eb2b8bd1a00e1dc26b6a4ac7fe..aacec984696f01338591d02578161fe1bc9dfded 100644 (file)
@@ -49,8 +49,8 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
     uint idst = i1*ne0 + i0;
     const uint ix = rope_a_coord(i0, i01, i02, p);
 
-    // Fusion optimization: ROPE + VIEW + SET_ROWS..
-    // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
     if (p.set_rows_stride != 0) {
         idst = i01*ne0 + i0;
         idst += rope_data_i[i02].x * p.set_rows_stride;
@@ -91,7 +91,7 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
     uint idst = i1*ne0 + i0/2;
     const uint ix = rope_a_coord(i0/2, i01, i02, p);
 
-    // Fusion optimization: ROPE + VIEW + SET_ROWS..
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
     // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
     if (p.set_rows_stride != 0) {
         idst = i01*ne0 + i0/2;
@@ -132,9 +132,16 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
     const uint i01 = i1 % ne1;
     const uint i02 = i1 / ne1;
 
-    const uint idst = i1*ne0 + i0/2;
+    uint idst = i1*ne0 + i0/2;
     const uint ix = rope_a_coord(i0/2, i01, i02, p);
 
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
+    if (p.set_rows_stride != 0) {
+        idst = i01*ne0 + i0/2;
+        idst += rope_data_i[i02].x * p.set_rows_stride;
+    }
+
     if (i0 >= p.n_dims) {
         rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
         rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]);
index 92ad3bcab1e58d7e0d6cb4c26cfdde9c5e64151f..e237a8e102cbfae25118953fe5baadbb61a8d6ac 100644 (file)
@@ -927,6 +927,8 @@ void process_shaders() {
     string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
     string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
     string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
+    string_to_spv("rope_multi_f32_f16", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
+    string_to_spv("rope_multi_f32_f16_rte", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
 
     string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
     string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});