]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: support multi/vision rope, and noncontiguous rope (#11902)
authorJeff Bolz <redacted>
Sun, 16 Feb 2025 07:52:23 +0000 (01:52 -0600)
committerGitHub <redacted>
Sun, 16 Feb 2025 07:52:23 +0000 (08:52 +0100)
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp
ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 68f2ea14bafd3f66908ea3751364c79f8d8fba2e..88f31c1ef8b2f7dc4ff950715e93782c108b6dd6 100644 (file)
@@ -251,6 +251,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
+    vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
+    vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32;
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -494,6 +496,10 @@ struct vk_op_rope_push_constants {
     float corr_dims[2];
     float theta_scale;
     uint32_t has_ff;
+    uint32_t ne02;
+    uint32_t s1;
+    uint32_t s2;
+    int32_t sections[4];
 };
 
 struct vk_op_soft_max_push_constants {
@@ -2180,13 +2186,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
 
     if (device->float_controls_rte_fp16) {
         ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     } else {
         ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
@@ -5307,6 +5319,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         {
             const int mode = ((const int32_t *) dst->op_params)[2];
             const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+            const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+            const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
             if (is_neox) {
                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
@@ -5315,6 +5329,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
                     return ctx->device->pipeline_rope_neox_f16;
                 }
+            } else if (is_mrope && !is_vision) {
+                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+                    return ctx->device->pipeline_rope_multi_f32;
+                }
+                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+                    return ctx->device->pipeline_rope_multi_f16;
+                }
+            } else if (is_vision) {
+                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+                    return ctx->device->pipeline_rope_vision_f32;
+                }
+                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+                    return ctx->device->pipeline_rope_vision_f16;
+                }
             } else {
                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
                     return ctx->device->pipeline_rope_norm_f32;
@@ -5385,6 +5413,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_REPEAT:
+    case GGML_OP_ROPE:
         return true;
     default:
         return false;
@@ -6149,7 +6178,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
 
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
     const int n_dims        = ((int32_t *) dst->op_params)[1];
-    // const int mode          = ((int32_t *) dst->op_params)[2];
+    const int mode          = ((int32_t *) dst->op_params)[2];
     // const int n_ctx         = ((int32_t *) dst->op_params)[3];
     const int n_ctx_orig    = ((int32_t *) dst->op_params)[4];
     const float freq_base   = ((float *)   dst->op_params)[5];
@@ -6158,16 +6187,24 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
     const float attn_factor = ((float *)   dst->op_params)[8];
     const float beta_fast   = ((float *)   dst->op_params)[9];
     const float beta_slow   = ((float *)   dst->op_params)[10];
+    int sections[4] {};
+    if (mode & GGML_ROPE_TYPE_MROPE) {
+        memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
+    }
 
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
+    uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type);
+    uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type);
+
     ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
         (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
         freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
-        src2 != nullptr,
+        src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
+        sections[0], sections[1], sections[2], sections[3],
     }, dryrun);
 }
 
@@ -8264,16 +8301,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_REPEAT:
             return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
         case GGML_OP_ROPE:
-            {
-                const int mode = ((const int32_t *) op->op_params)[2];
-                if (mode & GGML_ROPE_TYPE_MROPE) {
-                    return false;
-                }
-                if (mode & GGML_ROPE_TYPE_VISION) {
-                    return false;
-                }
-                return ggml_is_contiguous(op->src[0]);
-            }
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -8831,7 +8858,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const float attn_factor     = ((float *) tensor->op_params)[8];
         const float beta_fast       = ((float *) tensor->op_params)[9];
         const float beta_slow       = ((float *) tensor->op_params)[10];
-        tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+        if (mode & GGML_ROPE_TYPE_MROPE) {
+            int32_t *sections = ((int32_t *) tensor->op_params) + 11;
+            tensor_clone = ggml_rope_multi(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+        } else {
+            tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+        }
     } else if (tensor->op == GGML_OP_UNARY) {
         switch (ggml_get_unary_op(tensor)) {
         case GGML_UNARY_OP_SILU:
index 574b51ca553898ed90a05b7b9d0d92745545dc22..38075b75557f906ca3490401d3261b9e560f6d74 100644 (file)
@@ -25,6 +25,10 @@ layout (push_constant) uniform parameter {
     float corr_dims[2];
     float theta_scale;
     uint has_ff;
+    uint ne02;
+    uint s1;
+    uint s2;
+    int sections[4];
 } p;
 
 float rope_yarn_ramp(const float low, const float high, const uint i0) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
new file mode 100644 (file)
index 0000000..4f5b1a0
--- /dev/null
@@ -0,0 +1,60 @@
+#version 450
+
+#include "rope_head.comp"
+
+void main() {
+    const uint i0 = 2*gl_GlobalInvocationID.y;
+    uint ne0 = p.ncols;
+    uint ne1 = p.p_delta_rows;
+    uint ne2 = p.ne02;
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const uint row_dst = gl_GlobalInvocationID.x;
+
+    if (i0 >= p.n_dims) {
+        const uint i = row_dst*ne0 + i0;
+
+        data_d[i + 0] = data_a[i + 0];
+        data_d[i + 1] = data_a[i + 1];
+
+        return;
+    }
+
+    const uint row_x     = row_dst % ne1;
+    const uint channel_x = row_dst / ne1;
+
+    const uint idst = row_dst*ne0 + i0/2;
+    const uint ix   = channel_x*p.s2 + row_x*p.s1 + i0/2;
+
+    const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
+    const int sec_w = p.sections[1] + p.sections[0];
+    const uint sector = (i0 / 2) % sect_dims;
+
+    float theta_base = 0.0;
+    if (sector < p.sections[0]) {
+        theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
+    }
+    else if (sector >= p.sections[0] && sector < sec_w) {
+        theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+    }
+    else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
+        theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+    }
+    else if (sector >= sec_w + p.sections[2]) {
+        theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+    }
+
+    const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
+
+    float cos_theta, sin_theta;
+    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
+
+    const float x0 = float(data_a[ix + 0]);
+    const float x1 = float(data_a[ix + p.n_dims/2]);
+
+    data_d[idst + 0]          = D_TYPE(x0*cos_theta - x1*sin_theta);
+    data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
+}
index 83b46b69b2a7f22cf18f1b3851658c6f87f4cce0..db775c456cae82f2767b159f9101b11eee4abc12 100644 (file)
@@ -3,15 +3,18 @@
 #include "rope_head.comp"
 
 void main() {
-    const uint col = gl_GlobalInvocationID.y * 2;
-    const uint row = gl_GlobalInvocationID.x;
+    const uint i0 = 2*gl_GlobalInvocationID.y;
+    uint ne0 = p.ncols;
+    uint ne1 = p.p_delta_rows;
 
-    if (col >= p.ncols) {
+    if (i0 >= ne0) {
         return;
     }
 
-    if (col >= p.n_dims) {
-        const uint i = row*p.ncols + col;
+    const uint row_dst = gl_GlobalInvocationID.x;
+
+    if (i0 >= p.n_dims) {
+        const uint i = row_dst*ne0 + i0;
 
         data_d[i + 0] = data_a[i + 0];
         data_d[i + 1] = data_a[i + 1];
@@ -19,19 +22,22 @@ void main() {
         return;
     }
 
-    const uint i  = row*p.ncols + col/2;
-    const uint i2 = row/p.p_delta_rows;
+    const uint row_x     = row_dst % ne1;
+    const uint channel_x = row_dst / ne1;
+
+    const uint idst = row_dst*ne0 + i0/2;
+    const uint ix   = channel_x*p.s2 + row_x*p.s1 + i0/2;
 
-    const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
+    const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
 
-    const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
+    const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
 
     float cos_theta, sin_theta;
-    rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
+    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
 
-    const float x0 = float(data_a[i + 0]);
-    const float x1 = float(data_a[i + p.n_dims/2]);
+    const float x0 = float(data_a[ix + 0]);
+    const float x1 = float(data_a[ix + p.n_dims/2]);
 
-    data_d[i + 0]        = D_TYPE(x0*cos_theta - x1*sin_theta);
-    data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
+    data_d[idst + 0]          = D_TYPE(x0*cos_theta - x1*sin_theta);
+    data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
 }
index e416ad93897063f710d70d8225a6250fd6b6d875..4ad35e549d77fa9db0317f6caa4b4427089336ed 100644 (file)
@@ -3,15 +3,18 @@
 #include "rope_head.comp"
 
 void main() {
-    const uint col = gl_GlobalInvocationID.y * 2;
-    const uint row = gl_GlobalInvocationID.x;
+    const uint i0 = 2*gl_GlobalInvocationID.y;
+    uint ne0 = p.ncols;
+    uint ne1 = p.p_delta_rows;
 
-    if (col >= p.ncols) {
+    if (i0 >= ne0) {
         return;
     }
 
-    if (col >= p.n_dims) {
-        const uint i = row*p.ncols + col;
+    const uint row_dst = gl_GlobalInvocationID.x;
+
+    if (i0 >= p.n_dims) {
+        const uint i = row_dst*ne0 + i0;
 
         data_d[i + 0] = data_a[i + 0];
         data_d[i + 1] = data_a[i + 1];
@@ -19,19 +22,22 @@ void main() {
         return;
     }
 
-    const uint i = row*p.ncols + col;
-    const uint i2 = row/p.p_delta_rows;
+    const uint row_x     = row_dst % ne1;
+    const uint channel_x = row_dst / ne1;
+
+    const uint idst = row_dst*ne0 + i0;
+    const uint ix   = channel_x*p.s2 + row_x*p.s1 + i0;
 
-    const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
+    const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
 
-    const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
+    const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
 
     float cos_theta, sin_theta;
-    rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
+    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
 
-    const float x0 = float(data_a[i + 0]);
-    const float x1 = float(data_a[i + 1]);
+    const float x0 = float(data_a[ix + 0]);
+    const float x1 = float(data_a[ix + 1]);
 
-    data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
-    data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
+    data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
+    data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
new file mode 100644 (file)
index 0000000..cedacc4
--- /dev/null
@@ -0,0 +1,47 @@
+#version 450
+
+#include "rope_head.comp"
+
+void main() {
+    const uint i0 = 2*gl_GlobalInvocationID.y;
+    uint ne0 = p.ncols;
+    uint ne1 = p.p_delta_rows;
+    uint ne2 = p.ne02;
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const uint row_dst = gl_GlobalInvocationID.x;
+
+    const uint row_x     = row_dst % ne1;
+    const uint channel_x = row_dst / ne1;
+
+    const uint idst = row_dst*ne0 + i0/2;
+    const uint ix   = channel_x*p.s2 + row_x*p.s1 + i0/2;
+
+    const int sect_dims = p.sections[0] + p.sections[1];
+    const int sec_w = p.sections[1] + p.sections[0];
+    const uint sector = (i0 / 2) % sect_dims;
+
+    float theta_base = 0.0;
+    if (sector < p.sections[0]) {
+        const uint p0 = sector;
+        theta_base = data_pos[channel_x]*pow(p.theta_scale, p0);
+    }
+    else if (sector >= p.sections[0] && sector < sec_w) {
+        const uint p0 = sector - p.sections[0];
+        theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0);
+    }
+
+    const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
+
+    float cos_theta, sin_theta;
+    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
+
+    const float x0 = float(data_a[ix + 0]);
+    const float x1 = float(data_a[ix + p.n_dims]);
+
+    data_d[idst + 0]        = D_TYPE(x0*cos_theta - x1*sin_theta);
+    data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta);
+}
index 601cd4e7d30df64afbf88f285dd40afe99c58a07..ba9163af27ad965f1e7499380a5bc35a40d143ec 100644 (file)
@@ -491,6 +491,14 @@ void process_shaders() {
     string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
     string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
 
+    string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+
+    string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+
     string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
 
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));