]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: support L2_NORM with contiguous rows (llama/19604)
authorJeff Bolz <redacted>
Sat, 14 Feb 2026 05:42:04 +0000 (21:42 -0800)
committerGeorgi Gerganov <redacted>
Sat, 14 Feb 2026 22:20:18 +0000 (00:20 +0200)
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/l2_norm.comp
tests/test-backend-ops.cpp

index a9f75f0d00dc98b504db2da8125018c5557b5882..114992da08d744e1941aa5245afa17b434b77ae5 100644 (file)
@@ -4082,7 +4082,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -10639,8 +10639,10 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
 }
 
 static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
+    const float * op_params = (const float *)dst->op_params;
+    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
+    p.param1 = op_params[0];
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, std::move(p));
 }
 
 static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -14912,8 +14914,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return true;
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
-        case GGML_OP_L2_NORM:
             return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_L2_NORM:
+            return ggml_is_contiguous_rows(op->src[0]) &&
+                   op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_ADD:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
index 83ef2f879584577d3c074b6bd1fd42fd329e087b..7d0a1de0df9bd5ac5ba0a05921ae796816ed3a08 100644 (file)
@@ -1,6 +1,6 @@
 #version 450
 
-#include "generic_head.glsl"
+#include "generic_unary_head.glsl"
 #include "types.glsl"
 
 #extension GL_EXT_control_flow_attributes : enable
@@ -8,19 +8,22 @@
 
 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
 shared FLOAT_TYPE sum[BLOCK_SIZE];
 
 void main() {
     const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
     const uint tid = gl_LocalInvocationID.x;
 
+    const uint i3 = row / (p.ne11 * p.ne12);
+    const uint i3_offset = i3 * p.ne12 * p.ne11;
+    const uint i2 = (row - i3_offset) / p.ne11;
+    const uint i2_offset = i2 * p.ne11;
+    const uint i1 = row - i3_offset - i2_offset;
+
     sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+    [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
+        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
         sum[tid] += xi * xi;
     }
 
@@ -35,7 +38,7 @@ void main() {
 
     const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+    [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
+        data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
     }
 }
index a50c569b820ffef84a6a1e34796d8eb41c09d922..d818ed67d65cc9f9b43f4e532440949aab1a87d0 100644 (file)
@@ -5821,20 +5821,27 @@ struct test_l2_norm : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const float eps;
+    bool v;
 
     std::string vars() override {
-        return VARS_TO_STR2(type, ne);
+        return VARS_TO_STR4(type, ne, eps, v);
     }
 
     test_l2_norm(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 64, 320, 1},
-            float eps = 1e-12f)
-        : type(type), ne(ne), eps(eps) {}
+            float eps = 1e-12f,
+            bool v = false)
+        : type(type), ne(ne), eps(eps), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(a, "a");
 
+        if (v) {
+            a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view of a");
+        }
+
         ggml_tensor * out = ggml_l2_norm(ctx, a, eps);
         ggml_set_name(out, "out");
 
@@ -7596,7 +7603,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
             }
             test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
-            test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
+            test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
+            test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
         }
     }