sycl: add RMS_NORM_BACK operation support (#16808)
author    YaelLogic <redacted>
Wed, 29 Oct 2025 06:14:39 +0000 (08:14 +0200)
committer GitHub <redacted>
Wed, 29 Oct 2025 06:14:39 +0000 (14:14 +0800)
* sycl: add RMS_NORM_BACK operation support

* sycl: rms_norm_back: add dual reduction paths (FP64 and FP32) and a savepoint before further changes

* sycl: add RMS_NORM_BACK support

Implement RMS_NORM_BACK for the SYCL backend using an FP32 compensated parallel reduction (a sketch of the compensated summation follows the commit message). Minimal docs updates (ops.md / SYCL.csv).

* revert: restore .gitignore and tools/run/CMakeLists.txt to upstream

* revert: restore tests/CMakeLists.txt to upstream

* sycl: optimize rms_norm_back

* fix: restore SYCL.csv to correct state with RMS_NORM_BACK support

* Update ggml/src/ggml-sycl/norm.cpp

Co-authored-by: Neo Zhang Jianyu <redacted>
* fix: remove trailing whitespace and add missing newline (EditorConfig)

---------

Co-authored-by: Neo Zhang Jianyu <redacted>
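
For reference, the "compensated" accumulation named above is Kahan summation: each running sum carries a correction term that recaptures the low-order bits lost to FP32 rounding. A minimal scalar sketch of the scheme (illustrative only, not code from this commit):

    #include <cstdint>

    // Kahan (compensated) summation: `c` accumulates the rounding error of each add.
    static float kahan_sum(const float * v, int64_t n) {
        float sum = 0.0f, c = 0.0f;
        for (int64_t i = 0; i < n; ++i) {
            const float y = v[i] - c; // apply the stored correction
            const float t = sum + y;  // low-order bits of y may be lost here
            c = (t - sum) - y;        // recover what the add just lost
            sum = t;
        }
        return sum;
    }

The per-thread loop in the norm.cpp hunk below applies this same update to its two running sums (x*x and x*dz), then reduces the (sum, c) pairs across the warp and work-group.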
docs/ops.md
docs/ops/SYCL.csv
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-sycl/norm.cpp
ggml/src/ggml-sycl/norm.hpp

diff --git a/docs/ops.md b/docs/ops.md
index dfd1cfab6a8b28193d8c83dbeac366275168e01a..3738a48072832339d5b8ac6b3ecbe5ba53d4b3ca 100644
@@ -79,7 +79,7 @@ Legend:
 |                           REPEAT | โŒ | โœ… | โœ… | ๐ŸŸก | โœ… | ๐ŸŸก | โœ… | ๐ŸŸก | โŒ |
 |                      REPEAT_BACK | โŒ | โŒ | โœ… | โœ… | โŒ | โŒ | โŒ | โœ… | โŒ |
 |                         RMS_NORM | โŒ | โœ… | โœ… | โœ… | ๐ŸŸก | โœ… | โœ… | โœ… | โŒ |
-|                    RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
+|                    RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
 |                 RMS_NORM_MUL_ADD | โŒ | โœ… | โœ… | โœ… | โœ… | โœ… | โœ… | โœ… | โŒ |
 |                             ROLL | โŒ | โŒ | โœ… | โŒ | โŒ | โŒ | โŒ | โœ… | โŒ |
 |                             ROPE | โŒ | ๐ŸŸก | โœ… | โœ… | โœ… | โœ… | โœ… | โœ… | โŒ |
diff --git a/docs/ops/SYCL.csv b/docs/ops/SYCL.csv
index fe6876357f359a6ac41c4179b28fcfbd18e8e66e..101e80f64c662eb88816cdbb8aa832f07c8c5b25 100644
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000000,inplace=0","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","SYCL"
-"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","0","no","SYCL"
+"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","1","yes","SYCL"
 "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","SYCL"
-"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","0","no","SYCL"
+"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","1","yes","SYCL"
 "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","SYCL"
-"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","0","no","SYCL"
+"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","1","yes","SYCL"
 "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","SYCL"
 "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","SYCL"
-"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","0","no","SYCL"
+"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","1","yes","SYCL"
 "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
 "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","SYCL"
 "SYCL0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=0.000000,broadcast=0,multi_add=0","support","1","yes","SYCL"
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 328d1a71b75802c50a8ea032578af42d8a20df10..c97c5899435b18160260f6f9b4a26702a3ceb1f0 100644
@@ -42,6 +42,7 @@
 #include "ggml-sycl/backend.hpp"
 #include "ggml-sycl/common.hpp"
 #include "ggml-sycl/element_wise.hpp"
+#include "ggml-sycl/norm.hpp"
 #include "ggml-sycl/presets.hpp"
 #include "ggml-sycl/gemm.hpp"
 #include "ggml-sycl/set_rows.hpp"
@@ -2637,6 +2638,11 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
     ggml_sycl_op_rms_norm(ctx, dst);
 }
 
+static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
+    ggml_sycl_op_rms_norm_back(ctx, dst);
+}
+
 static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_l2_norm(ctx, dst);
@@ -3827,6 +3833,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_LEAKY_RELU:
             ggml_sycl_leaky_relu(ctx, dst);
             break;
+        case GGML_OP_RMS_NORM_BACK:
+            ggml_sycl_rms_norm_back(ctx, dst);
+            break;
         case GGML_OP_RMS_NORM:
             ggml_sycl_rms_norm(ctx, dst);
             break;
@@ -4571,6 +4580,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_RMS_NORM:
             return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
+        case GGML_OP_RMS_NORM_BACK:
+            return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
         case GGML_OP_SCALE:
             return true;
         case GGML_OP_CONT:
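
Note on the supports_op gate added above: RMS_NORM_BACK is only advertised when the row width ne[0] is a multiple of WARP_SIZE, so the scheduler routes other shapes to a different backend. A caller can probe this through the existing public API (a sketch; the dev and node handles are assumed to come from an existing graph):

    #include "ggml-backend.h"

    // Ask a device whether it can execute a graph node. For GGML_OP_RMS_NORM_BACK
    // on SYCL this now returns true only when node->src[0]->ne[0] % WARP_SIZE == 0.
    static bool device_can_run(ggml_backend_dev_t dev, const struct ggml_tensor * node) {
        return ggml_backend_dev_supports_op(dev, node);
    }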
diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp
index 4ec1416849c7e718f27f5cd40ccf1946ced3926f..823d3a4828cc925cdeea2cb5c5434d51f371d1f5 100644
@@ -480,6 +480,162 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
 }
 
+void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
+
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); // dz
+    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); // x
+    GGML_ASSERT(dst->type         == GGML_TYPE_F32);
+
+    float eps = 1e-5f;
+    std::memcpy(&eps, dst->op_params, sizeof(float));
+    if (!(eps > 0.0f) || !std::isfinite(eps)) eps = 1e-5f;
+
+    const float * g_base  = static_cast<const float *>(dst->src[0]->data); // dz
+    const float * x_base  = static_cast<const float *>(dst->src[1]->data); // x
+          float * dx_base = static_cast<      float *>(dst->data);
+
+    const int64_t D  = dst->ne[0];
+    const int64_t n1 = dst->ne[1], n2 = dst->ne[2], n3 = dst->ne[3]; (void) n3;
+    const int64_t N  = ggml_nrows(dst);
+    if (D == 0 || N == 0) return;
+
+    const ggml_tensor *G = dst->src[0];
+    const ggml_tensor *X = dst->src[1];
+    const int ts = (int) ggml_type_size(X->type);
+    GGML_ASSERT((size_t) X->nb[0]   == (size_t) ts);
+    GGML_ASSERT((size_t) G->nb[0]   == (size_t) ts);
+    GGML_ASSERT((size_t) dst->nb[0] == (size_t) ts);
+
+    const int64_t xs1 = X->nb[1] / ts, xs2 = X->nb[2] / ts, xs3 = X->nb[3] / ts;
+    const int64_t gs1 = G->nb[1] / ts, gs2 = G->nb[2] / ts, gs3 = G->nb[3] / ts;
+    const int64_t ds1 = dst->nb[1] / ts, ds2 = dst->nb[2] / ts, ds3 = dst->nb[3] / ts;
+
+    dpct::queue_ptr q = ctx.stream();
+
+    // work-group size: multiple of WARP_SIZE, capped by device and 256, and not larger than D
+    const int device_max_wg = ggml_sycl_info().max_work_group_sizes[ctx.device];
+    auto roundup = [](int v, int m) { return ((v + m - 1) / m) * m; };
+    int wg_cap = 256;
+    if (device_max_wg > 0) wg_cap = std::min(wg_cap, device_max_wg);
+    int WG = std::max(WARP_SIZE, std::min(roundup((int)std::min<int64_t>(D, wg_cap), WARP_SIZE), wg_cap));
+
+    // FP32 path: per-thread compensated accumulation + hierarchical reduction
+    q->submit([&](sycl::handler &cgh) {
+        const int nwarps_loc = std::max(1, WG / WARP_SIZE);
+        // store one partial value per warp (xx and xg) for cross-warp reduction
+        auto l_xx   = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
+        auto l_xg   = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
+
+        cgh.parallel_for(
+            sycl::nd_range<3>(sycl::range<3>(1, 1, N) * sycl::range<3>(1, 1, WG),
+                              sycl::range<3>(1, 1, WG)),
+            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                const int row = item_ct1.get_group(2);
+                const int tid = item_ct1.get_local_id(2);
+
+                const int64_t i1 = row % n1;
+                const int64_t i2 = (row / n1) % n2;
+                const int64_t i3 = row / (n1 * n2);
+
+                const float *__restrict x_row = x_base + i3 * xs3 + i2 * xs2 + i1 * xs1;
+                const float *__restrict g_row = g_base + i3 * gs3 + i2 * gs2 + i1 * gs1;
+                float *__restrict d_row       = dx_base + i3 * ds3 + i2 * ds2 + i1 * ds1;
+
+                // per-thread accumulation (compensated by default)
+                float sum_xx = 0.f, sum_xg = 0.f;
+#ifndef GGML_SYCL_RMS_BACK_FAST
+                float c_xx = 0.f, c_xg = 0.f;
+#endif
+                for (int64_t col = tid; col < D; col += WG) {
+                    const float xv = x_row[col];
+                    const float gv = g_row[col];
+#ifdef GGML_SYCL_RMS_BACK_FAST
+                    sum_xx += xv * xv;
+                    sum_xg += xv * gv;
+#else
+                    float y1 = xv * xv - c_xx;
+                    float t1 = sum_xx + y1;
+                    c_xx = (t1 - sum_xx) - y1;
+                    sum_xx = t1;
+
+                    float y2 = xv * gv - c_xg;
+                    float t2 = sum_xg + y2;
+                    c_xg = (t2 - sum_xg) - y2;
+                    sum_xg = t2;
+#endif
+                }
+
+                // warp-level reduction
+                sycl::float2 xx = sycl::float2(sum_xx,
+#ifndef GGML_SYCL_RMS_BACK_FAST
+                    c_xx
+#else
+                    0.f
+#endif
+                );
+                sycl::float2 xg = sycl::float2(sum_xg,
+#ifndef GGML_SYCL_RMS_BACK_FAST
+                    c_xg
+#else
+                    0.f
+#endif
+                );
+                xx = warp_reduce_sum(xx, item_ct1);
+                xg = warp_reduce_sum(xg, item_ct1);
+
+                // cross-warp reduction using local memory (single barrier)
+                const auto sub_group = item_ct1.get_sub_group();
+                const auto sg_id     = sub_group.get_group_linear_id();
+                const auto wi_in_sg  = sub_group.get_local_linear_id();
+                const int nthreads   = item_ct1.get_local_range(2);
+                const int nwarps     = nthreads / WARP_SIZE;
+
+                sycl::float2 xx_total = xx;
+                sycl::float2 xg_total = xg;
+                if (nwarps > 1) {
+                    if (wi_in_sg == 0) {
+                        l_xx[sg_id] = xx;
+                        l_xg[sg_id] = xg;
+                    }
+                    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+                    if (sg_id == 0) {
+                        const unsigned wi_u = wi_in_sg;
+                        sycl::float2 xx_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xx[wi_u] : sycl::float2(0.f, 0.f);
+                        sycl::float2 xg_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xg[wi_u] : sycl::float2(0.f, 0.f);
+                        xx_total = warp_reduce_sum(xx_first, item_ct1);
+                        xg_total = warp_reduce_sum(xg_first, item_ct1);
+                    } else {
+                        // other subgroups keep their local totals; they'll be ignored
+                        xx_total = xx;
+                        xg_total = xg;
+                    }
+                    // ensure all threads see the first-subgroup result via broadcast below
+                }
+
+                // compute inv_r and coeff once per row and broadcast to the whole work-group
+                float inv_r = 0.f;
+                float coeff = 0.f;
+                if (tid == 0) {
+                    const float sum_xx_f  = xx_total.x() + xx_total.y();
+                    const float sum_xdz_f = xg_total.x() + xg_total.y();
+                    const float mean_eps  = sum_xx_f / (float) D + eps;
+                    const float sum_eps   = sum_xx_f + eps * (float) D;
+                    inv_r = sycl::rsqrt(mean_eps);
+                    coeff = -sum_xdz_f / sum_eps;
+                }
+                inv_r = sycl::group_broadcast(item_ct1.get_group(), inv_r);
+                coeff = sycl::group_broadcast(item_ct1.get_group(), coeff);
+
+                for (int64_t col = tid; col < D; col += WG) {
+                    d_row[col] = (g_row[col] + coeff * x_row[col]) * inv_r;
+                }
+            });
+    });
+
+}
+
 void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
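
For clarity, with S_xx = sum_i x_i^2 and S_xdz = sum_i x_i*dz_i, the kernel above computes per row:

    dx_i = (dz_i - (S_xdz / (S_xx + eps*D)) * x_i) / sqrt(S_xx/D + eps)

A scalar reference sketch of the same computation (illustrative; FP64 accumulators stand in for the kernel's compensated FP32 sums):

    #include <cmath>
    #include <cstdint>

    // Reference RMS_NORM_BACK for one row of length D: dz is the incoming
    // gradient, x the forward input, dx the resulting gradient w.r.t. x.
    static void rms_norm_back_row_ref(const float * dz, const float * x,
                                      float * dx, int64_t D, float eps) {
        double sum_xx = 0.0, sum_xdz = 0.0;
        for (int64_t i = 0; i < D; ++i) {
            sum_xx  += (double) x[i] * x[i];
            sum_xdz += (double) x[i] * dz[i];
        }
        const float inv_r = 1.0f / std::sqrt((float) (sum_xx / (double) D) + eps); // 1/rms
        const float coeff = -(float) sum_xdz / ((float) sum_xx + eps * (float) D);
        for (int64_t i = 0; i < D; ++i) {
            dx[i] = (dz[i] + coeff * x[i]) * inv_r;
        }
    }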
diff --git a/ggml/src/ggml-sycl/norm.hpp b/ggml/src/ggml-sycl/norm.hpp
index 612cd67cf9183d502b88004dc950087ca5562dd4..8cb885eb2eed541c620f252fbacaa1493b853627 100644
@@ -19,6 +19,8 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
 
 void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
 
+void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
+
 void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
 
 void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);