]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: support noncontiguous rms_norm (llama/13031)
authorJeff Bolz <redacted>
Sun, 20 Apr 2025 08:50:02 +0000 (03:50 -0500)
committerGeorgi Gerganov <redacted>
Thu, 24 Apr 2025 17:39:16 +0000 (20:39 +0300)
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp

index 2f6d03c939eb576270a98d6d08a4bd544fa2c820..39f3cd343ac450d39092d65752ceda103e664e39 100644 (file)
@@ -2397,7 +2397,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
@@ -6006,6 +6006,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_ROPE:
+    case GGML_OP_RMS_NORM:
         return true;
     default:
         return false;
@@ -6216,7 +6217,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
     switch (op) {
     case GGML_OP_NORM:
-    case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_L2_NORM:
     case GGML_OP_SOFT_MAX:
@@ -6233,6 +6233,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
                 elements = { nr, 1, 1 };
             }
         } break;
+    case GGML_OP_RMS_NORM:
+        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        break;
+
     case GGML_OP_SUM:
         // We use GGML_OP_SUM_ROWS with 1 row.
         elements = { 1, 1, 1 };
@@ -6883,7 +6887,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
 
 static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        op_params[0], 0.0f,
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    }, dryrun);
 }
 
 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9388,10 +9402,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
+        case GGML_OP_RMS_NORM:
             return true;
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
-        case GGML_OP_RMS_NORM:
         case GGML_OP_L2_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
index b554400ba393f92227cf1ccd7c40549890e30a32..deb8ee9960f58240fb5ea33b06809ffb2f4632e9 100644 (file)
@@ -1,6 +1,6 @@
 #version 450
 
-#include "generic_head.comp"
+#include "generic_unary_head.comp"
 #include "types.comp"
 
 #extension GL_EXT_control_flow_attributes : enable
@@ -8,19 +8,29 @@
 
 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
 shared FLOAT_TYPE sum[BLOCK_SIZE];
 
 void main() {
-    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-    const uint tid = gl_LocalInvocationID.x;
+    const uint ncols     = p.ne00;
+    const uint nrows     = gl_NumWorkGroups.x;
+    const uint nchannels = gl_NumWorkGroups.y;
+
+    const uint row       = gl_WorkGroupID.x;
+    const uint channel   = gl_WorkGroupID.y;
+    const uint samp      = gl_WorkGroupID.z;
+    const uint tid       = gl_LocalInvocationID.x;
+
+    const uint stride_row       = p.nb01;
+    const uint stride_channel   = p.nb02;
+    const uint stride_sample    = p.nb03;
+
+    uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
+    uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
 
     sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
         sum[tid] += xi * xi;
     }
 
@@ -33,10 +43,10 @@ void main() {
         barrier();
     }
 
-    const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
+    const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
     const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+        data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
     }
 }