]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
SYCL: Add non contiguous support in RMS_NORM and NORM kernels (llama/13611)
authorAkarshan Biswas <redacted>
Mon, 26 May 2025 15:40:36 +0000 (21:10 +0530)
committerGeorgi Gerganov <redacted>
Tue, 27 May 2025 15:03:00 +0000 (18:03 +0300)
* SYCL: Add non contiguous input support to norm kernel

* refactor and add RMS_NORM non contiguous input support

ggml-ci

* restore subgroup reduction for multi-subgroup thread blocks in norm kernels

* Swap grid dims of nsamples and nrows

ggml-ci

* Revert "Swap grid dims of nsamples and nrows"

This reverts commit 43be2d657fec7f7fba54e2cd154106bc0fc45adf.

* restore not required changes
ggml-ci

* address review comments: change it to more like SYCL

* Use a common function to calculate offset

* remove wrap around logic for handling broadcasts

* remove static from calculate_offset fn and use ceil_div

ggml/src/ggml-sycl/common.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-sycl/norm.cpp

index 03b6956d4b54be098d7f705829687a03602ffa91..15ee9dc69d149ff8b2339d13a3c4f4dfca48a713 100644 (file)
@@ -13,6 +13,7 @@
 #ifndef GGML_SYCL_COMMON_HPP
 #define GGML_SYCL_COMMON_HPP
 
+#include <cstddef>
 #include <fstream>
 #include <iostream>
 #include <string>
@@ -481,6 +482,19 @@ static __dpct_inline__ float warp_reduce_max(float x,
     return x;
 }
 
+/* Helper for Computing the linear offset of a ggml_tensor given
+per-dimension sizes, strides, and indices */
+template<int N>
+__dpct_inline__ size_t calculate_offset(const std::array<int, N> & strides, const std::array<int, N> & indices) {
+    size_t offset = 0;
+#pragma unroll
+    for (int i = 0; i < N; i++) {
+        auto index_i = indices[i];
+        offset += strides[i] * index_i;
+    }
+    return offset;
+}
+
 // Helper for vec loading aligned data
 template <typename Tp, int n>
 inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
index 134ec78a0b484d739e7ebae56a2eb6e8ed5b56d1..6a53bd12c4efb6ccfdd76529a444c5abbc596231 100644 (file)
@@ -4241,6 +4241,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
 #endif
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+            return true;
         case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
             return ggml_is_contiguous(op->src[0]);
index 4e9f438b46ba6613330ab99a0a2f4f0e083896fb..4ec1416849c7e718f27f5cd40ccf1946ced3926f 100644 (file)
@@ -1,40 +1,50 @@
 #include "norm.hpp"
+#include "ggml-sycl/common.hpp"
+#include "ggml-sycl/presets.hpp"
 
-static void norm_f32(const float* x, float* dst, const int ncols, const float eps,
-    const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-        item_ct1.get_local_id(1);
-    const int tid = item_ct1.get_local_id(2);
+static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
+
+    const int nrows = item_ct1.get_group_range(2);
+    const int nchannels = item_ct1.get_group_range(1);
 
     const int nthreads = item_ct1.get_local_range(2);
+    const int sample  = item_ct1.get_group(0);
+    const int channel = item_ct1.get_group(1);
+    const int row     = item_ct1.get_group(2);
+
+    const int tid = item_ct1.get_local_id(2);
     const int nwarps = nthreads / WARP_SIZE;
+
+    const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
+    const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
+
+    x += strided_offset;
+    dst += packed_offset;
+
     sycl::float2 mean_var = sycl::float2(0.f, 0.f);
 
     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row * ncols + col];
+        const float xi = x[col];
         mean_var.x() += xi;
         mean_var.y() += xi * xi;
     }
 
     // sum up partial sums
     mean_var = warp_reduce_sum(mean_var, item_ct1);
-    if (block_size > WARP_SIZE) {
-
-        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = mean_var;
+    if  (block_size > WARP_SIZE) {
+        const auto sub_group = item_ct1.get_sub_group();
+        const auto sg_id = sub_group.get_group_linear_id();
+        const auto wi_in_sg = sub_group.get_local_linear_id();
+        if (wi_in_sg == 0) {
+            s_sum[sg_id] = mean_var;
         }
-        /*
-        DPCT1118:0: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
         item_ct1.barrier(sycl::access::fence_space::local_space);
         mean_var = 0.f;
-        size_t nreduce = nwarps / WARP_SIZE;
+        const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
         for (size_t i = 0; i < nreduce; i += 1)
         {
-            mean_var += s_sum[lane_id + i * WARP_SIZE];
+            mean_var += s_sum[wi_in_sg + i * WARP_SIZE];
         }
         mean_var = warp_reduce_sum(mean_var, item_ct1);
     }
@@ -44,7 +54,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
     const float inv_std = sycl::rsqrt(var + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std;
+        dst[col] = (x[col] - mean) * inv_std;
     }
 }
 
@@ -135,39 +145,51 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
     }
 }
 
-static void rms_norm_f32(const float* x, float* dst, const int ncols, const float eps,
-    const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-        item_ct1.get_local_id(1);
-    const int tid = item_ct1.get_local_id(2);
+static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
+
+    const int nrows = item_ct1.get_group_range(2);
+    const int nchannels = item_ct1.get_group_range(1);
+
+    const int sample  = item_ct1.get_group(0);
+    const int channel = item_ct1.get_group(1);
+    const int row     = item_ct1.get_group(2);
+
     const int nthreads = item_ct1.get_local_range(2);
+
+    const int tid = item_ct1.get_local_id(2);
     const int nwarps = nthreads / WARP_SIZE;
+
+    const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
+    const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
+
+    x   += strided_offset;
+    dst += packed_offset;
+
+
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row * ncols + col];
+        const float xi = x[col];
         tmp += xi * xi;
     }
 
     // sum up partial sums
     tmp = warp_reduce_sum(tmp, item_ct1);
     if (block_size > WARP_SIZE) {
-
-        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
+        const auto sub_group = item_ct1.get_sub_group();
+        const auto sg_id = sub_group.get_group_linear_id();
+        const auto wi_in_sg = sub_group.get_local_linear_id();
+        if (wi_in_sg == 0) {
+            s_sum[sg_id] = tmp;
         }
-        /*
-        DPCT1118:3: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
+
         item_ct1.barrier(sycl::access::fence_space::local_space);
-        size_t nreduce = nwarps / WARP_SIZE;
+        const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
         tmp = 0.f;
         for (size_t i = 0; i < nreduce; i += 1)
         {
-            tmp += s_sum[lane_id + i * WARP_SIZE];
+            tmp += s_sum[wi_in_sg + i * WARP_SIZE];
         }
         tmp = warp_reduce_sum(tmp, item_ct1);
     }
@@ -176,7 +198,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
     const float scale = sycl::rsqrt(mean + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[row * ncols + col] = scale * x[row * ncols + col];
+        dst[col] = scale * x[col];
     }
 }
 
@@ -224,20 +246,20 @@ static void l2_norm_f32(const float* x, float* dst, const int ncols, const float
     }
 }
 
-static void norm_f32_sycl(const float* x, float* dst, const int ncols,
-    const int nrows, const float eps,
-    queue_ptr stream, int device) {
+static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
+        const float eps, queue_ptr stream, int device) {
+
+    const sycl::range<3> global_dims(nsamples, nchannels, nrows);
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
         stream->submit([&](sycl::handler& cgh) {
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                    block_dims),
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                 [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                    norm_f32(x, dst, ncols, eps, item_ct1,
-                        nullptr, WARP_SIZE);
+                    norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
                 });
             });
     }
@@ -252,15 +274,12 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
         */
         stream->submit([&](sycl::handler& cgh) {
             sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
-                sycl::range<1>(work_group_size / WARP_SIZE), cgh);
-
+                            sycl::range<1>(work_group_size / WARP_SIZE), cgh);
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                    block_dims),
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                 [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                    norm_f32(x, dst, ncols, eps, item_ct1,
-                        get_pointer(s_sum_acc_ct1), work_group_size);
+                    norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
                 });
             });
     }
@@ -313,21 +332,20 @@ static void group_norm_f32_sycl(const float* x, float* dst,
     }
 }
 
-static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
-    const int nrows, const float eps,
-    queue_ptr stream, int device) {
+static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
+
+    const sycl::range<3> global_dims(nsamples, nchannels, nrows);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
         stream->submit([&](sycl::handler& cgh) {
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                    block_dims),
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                 [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                    rms_norm_f32(x, dst, ncols, eps, item_ct1,
-                        nullptr, WARP_SIZE);
+                    rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
                 });
             });
     }
@@ -344,12 +362,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                 cgh);
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                    block_dims),
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                 [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                    rms_norm_f32(x, dst, ncols, eps, item_ct1,
-                        get_pointer(s_sum_acc_ct1), work_group_size);
+                    rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
                 });
             });
     }
@@ -398,12 +414,12 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
 }
 
 void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+    const ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = dst->src[0]->ne[0];
-    const int64_t nrows = ggml_nrows(dst->src[0]);
+    GGML_TENSOR_UNARY_OP_LOCALS
     dpct::queue_ptr main_stream = ctx.stream();
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
@@ -411,8 +427,14 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
-
-    norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+    GGML_ASSERT(eps >= 0.0f);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
 }
 
 void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
@@ -436,11 +458,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 
 void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
+    const ggml_tensor * src0 = dst->src[0];
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = dst->src[0]->ne[0];
-    const int64_t nrows = ggml_nrows(dst->src[0]);
     dpct::queue_ptr main_stream = ctx.stream();
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
 
@@ -450,7 +471,13 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+    GGML_TENSOR_UNARY_OP_LOCALS
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+    rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
 }
 
 void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {