]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Fix the sub group size of Intel (llama/8106)
authorluoyu-intel <redacted>
Tue, 2 Jul 2024 02:16:00 +0000 (02:16 +0000)
committerGeorgi Gerganov <redacted>
Mon, 8 Jul 2024 11:53:55 +0000 (14:53 +0300)
* use warp_size macro for all sycl kernels

* fix mask of permute_sub_group_by_xor

* fix rms_norm with correct warp number

* fix rms_norm_f32/group_norm_f32

* move norm to norm.cpp file

* fix quantize bug

* fix mmvq's batch size

ggml/src/CMakeLists.txt
ggml/src/ggml-sycl.cpp

index d0f4097d8cd0c84935f878a27826fe00e222e844..a18198f1693e59c7be395ab8579b52c02a75010d 100644 (file)
@@ -486,9 +486,11 @@ if (GGML_SYCL)
     add_compile_options(-I./) #include DPCT
 
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing")
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
     if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
+        add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
+    else()
+        add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
     endif()
 
     file(GLOB   GGML_HEADERS_SYCL "ggml-sycl/*.hpp")
index 30d8a5b33b61335cdbf6fa6e02a3a7191708a378..76bad57e2320b3d5d896d08da52fb7d01a5ea7d1 100644 (file)
@@ -74,51 +74,6 @@ typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const gg
                                        const float *src1_dd, float *dst_dd,
                                        const queue_ptr &main_stream);
 
-static __dpct_inline__ float warp_reduce_sum(float x,
-                                             const sycl::nd_item<3> &item_ct1) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        /*
-        DPCT1096:98: The right-most dimension of the work-group used in the SYCL
-        kernel that calls this function may be less than "32". The function
-        "dpct::permute_sub_group_by_xor" may return an unexpected result on the
-        CPU device. Modify the size of the work-group to ensure that the value
-        of the right-most dimension is a multiple of "32".
-        */
-        x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask);
-    }
-    return x;
-}
-
-static __dpct_inline__ sycl::float2
-warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3> &item_ct1) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.x(),
-                                                mask);
-        a.y() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.y(),
-                                                mask);
-    }
-    return a;
-}
-
-static __dpct_inline__ float warp_reduce_max(float x,
-                                             const sycl::nd_item<3> &item_ct1) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        /*
-        DPCT1096:97: The right-most dimension of the work-group used in the SYCL
-        kernel that calls this function may be less than "32". The function
-        "dpct::permute_sub_group_by_xor" may return an unexpected result on the
-        CPU device. Modify the size of the work-group to ensure that the value
-        of the right-most dimension is a multiple of "32".
-        */
-        x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
-                              item_ct1.get_sub_group(), x, mask));
-    }
-    return x;
-}
-
 static __dpct_inline__ float op_repeat(const float a, const float b) {
     return b;
     GGML_UNUSED(a);
@@ -336,47 +291,6 @@ static void sqr_f32(const float * x, float * dst, const int k,
     dst[i] = x[i] * x[i];
 }
 
-static void norm_f32(const float * x, float * dst, const int ncols, const float eps,
-                     const sycl::nd_item<3> &item_ct1, sycl::float2 *s_sum, int block_size) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-    const int tid = item_ct1.get_local_id(2);
-
-    sycl::float2 mean_var = sycl::float2(0.f, 0.f);
-
-    for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row*ncols + col];
-        mean_var.x() += xi;
-        mean_var.y() += xi * xi;
-    }
-
-    // sum up partial sums
-    mean_var = warp_reduce_sum(mean_var, item_ct1);
-    if (block_size > WARP_SIZE) {
-
-        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = mean_var;
-        }
-        /*
-        DPCT1118:0: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-        mean_var = s_sum[lane_id];
-        mean_var = warp_reduce_sum(mean_var, item_ct1);
-    }
-
-    const float mean = mean_var.x() / ncols;
-    const float var = mean_var.y() / ncols - mean * mean;
-    const float inv_std = sycl::rsqrt(var + eps);
-
-    for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
-    }
-}
-
 static void concat_f32(const float  *x,const float  *y, float *dst, const int ne0, const int ne02,
                        const sycl::nd_item<3> &item_ct1) {
     int nidx = item_ct1.get_local_id(2) +
@@ -444,126 +358,11 @@ static void pad_f32(const float  *x, float *dst, const int ne0, const int ne00,
     }
 }
 
-static void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps,
-                           const sycl::nd_item<3> &item_ct1, float *s_sum, int block_size) {
-    int start = item_ct1.get_group(2) * group_size;
-    int end = start + group_size;
-
-    start += item_ct1.get_local_id(2);
-
-    if (end >= ne_elements) {
-        end = ne_elements;
-    }
-
-    float tmp = 0.0f; // partial sum for thread in warp
-
-    for (int j = start; j < end; j += block_size) {
-        tmp += x[j];
-    }
-
-    tmp = warp_reduce_sum(tmp, item_ct1);
-    if (block_size > WARP_SIZE) {
-
-        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        /*
-        DPCT1118:1: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
-        /*
-        DPCT1065:54: Consider replacing sycl::nd_item::barrier() with
-        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
-        better performance if there is no access to global memory.
-        */
-        item_ct1.barrier();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp, item_ct1);
-    }
-
-    float mean = tmp / group_size;
-    tmp = 0.0f;
-
-    for (int j = start; j < end; j += block_size) {
-        float xi = x[j] - mean;
-        dst[j] = xi;
-        tmp += xi * xi;
-    }
-
-    tmp = warp_reduce_sum(tmp, item_ct1);
-    if (block_size > WARP_SIZE) {
-
-        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        /*
-        DPCT1118:2: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
-        /*
-        DPCT1065:55: Consider replacing sycl::nd_item::barrier() with
-        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
-        better performance if there is no access to global memory.
-        */
-        item_ct1.barrier();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp, item_ct1);
-    }
-
-    float variance = tmp / group_size;
-    float scale = sycl::rsqrt(variance + eps);
-    for (int j = start; j < end; j += block_size) {
-        dst[j] *= scale;
-    }
-}
-
-static void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps,
-                         const sycl::nd_item<3> &item_ct1, float *s_sum, int block_size) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                    item_ct1.get_local_id(1);
-    const int tid = item_ct1.get_local_id(2);
-
-    float tmp = 0.0f; // partial sum for thread in warp
-
-    for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row*ncols + col];
-        tmp += xi * xi;
-    }
-
-    // sum up partial sums
-    tmp = warp_reduce_sum(tmp, item_ct1);
-    if (block_size > WARP_SIZE) {
-
-        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        /*
-        DPCT1118:3: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp, item_ct1);
-    }
-
-    const float mean = tmp / ncols;
-    const float scale = sycl::rsqrt(mean + eps);
-
-    for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = scale * x[row*ncols + col];
-    }
-}
-
+template<int QUANT_BLOCK_TILE>
 static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
                           const sycl::nd_item<3> &item_ct1) {
-    const int ix = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                   item_ct1.get_local_id(2);
+    const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                    item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
 
     if (ix >= kx_padded) {
         return;
@@ -578,23 +377,39 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
 
     const int ib = i_padded / QK8_1; // block index
     const int iqs = i_padded % QK8_1; // quant index
-
-    const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
-    float amax = sycl::fabs((float)xi);
-    float sum = xi;
-
+    typedef  sycl::vec<float, QUANT_BLOCK_TILE> TC;
+    typedef  sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
+    TC zeros;
+    TQ qzeros;
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        amax = sycl::fmax(amax, dpct::permute_sub_group_by_xor(
-                                    item_ct1.get_sub_group(), amax, mask));
-        sum +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), sum, mask);
+    for (int i = 0; i < QUANT_BLOCK_TILE; i++)
+    {
+        zeros[i] = 0.f;
+        qzeros[i] = 0;
+    }
+    const TC xi = ix < kx ? *(TC *)&x[iy * kx + ix] : zeros;
+    float sum = xi[0];
+    float amax = sycl::fabs(xi[0]);
+#pragma unroll
+    for (int i = 1; i < QUANT_BLOCK_TILE; i++)
+    {
+        sum += xi[i];
+        amax = sycl::fmax(sycl::fabs(xi[i]), amax);
     }
+    sum = warp_reduce_sum(sum, item_ct1);
+    amax = warp_reduce_max(amax, item_ct1);
 
     const float d = amax / 127;
-    const int8_t q = amax == 0.0f ? 0 : sycl::round(xi / d);
+    TQ q = qzeros;
+    if (amax != 0.0f)
+    {
+#pragma unroll
+        for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
+            q[i] = sycl::round(xi[i] / d);
+        }
+    }
 
-    y[ib].qs[iqs] = q;
+    *(TQ *)&y[ib].qs[iqs] = q;
 
     if (iqs > 0) {
         return;
@@ -728,7 +543,7 @@ static void mul_mat_p021_f16_f32(
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
+    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -781,7 +596,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
+    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -1643,99 +1458,6 @@ static void sqr_f32_sycl(const float *x, float *dst, const int k,
         });
 }
 
-static void norm_f32_sycl(const float *x, float *dst, const int ncols,
-                          const int nrows, const float eps,
-                          queue_ptr stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
-    if (ncols < 1024) {
-        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
-                sycl::range<1>(32), cgh);
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        norm_f32(x, dst, ncols, eps, item_ct1,
-                                            s_sum_acc_ct1.get_pointer(), WARP_SIZE);
-                    });
-        });
-    } else {
-        const int work_group_size = get_work_group_size(stream->get_device());
-        const sycl::range<3> block_dims(1, 1, work_group_size);
-        /*
-        DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
-                sycl::range<1>(32), cgh);
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        norm_f32(x, dst, ncols, eps, item_ct1,
-                                       s_sum_acc_ct1.get_pointer(), work_group_size);
-                    });
-        });
-    }
-}
-
-static void group_norm_f32_sycl(const float *x, float *dst,
-                                const int num_groups, const int group_size,
-                                const int ne_elements, queue_ptr stream) {
-    static const float eps = 1e-6f;
-    if (group_size < 1024) {
-        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
-                                                         cgh);
-
-            const float eps_ct4 = eps;
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        group_norm_f32(
-                            x, dst, group_size, ne_elements, eps_ct4, item_ct1,
-                            s_sum_acc_ct1.get_pointer(), WARP_SIZE);
-                    });
-        });
-    } else {
-        const int work_group_size = get_work_group_size(stream->get_device());
-        const sycl::range<3> block_dims(1, 1, work_group_size);
-        /*
-        DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
-                                                         cgh);
-
-            const float eps_ct4 = eps;
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        group_norm_f32(x, dst, group_size, ne_elements,
-                                             eps_ct4, item_ct1,
-                                             s_sum_acc_ct1.get_pointer(), work_group_size);
-                    });
-        });
-    }
-}
-
 static void concat_f32_sycl(const float *x, const float *y, float *dst,
                             const int ne0, int ne1, int ne2, int ne02,
                             queue_ptr stream) {
@@ -1777,64 +1499,22 @@ static void pad_f32_sycl(const float *x, float *dst, const int ne00,
         });
 }
 
-static void rms_norm_f32_sycl(const float *x, float *dst, const int ncols,
-                              const int nrows, const float eps,
-                              queue_ptr stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
-    // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
-    if (ncols < 1024) {
-        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
-                                                         cgh);
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
-                                                s_sum_acc_ct1.get_pointer(), WARP_SIZE);
-                    });
-        });
-    } else {
-        const int work_group_size = get_work_group_size(stream->get_device());
-        const sycl::range<3> block_dims(1, 1, work_group_size);
-        /*
-        DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(32),
-                                                         cgh);
-
-            cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                  block_dims),
-                [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(32)]] {
-                        rms_norm_f32(x, dst, ncols, eps, item_ct1,
-                                           s_sum_acc_ct1.get_pointer(), work_group_size);
-                    });
-        });
-    }
-}
-
 static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
                                    const int ky, const int kx_padded,
                                    queue_ptr stream) {
     const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
     const sycl::range<3> num_blocks(1, ky, block_num_x);
-    const sycl::range<3> block_size(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE);
+    int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
+    static_assert(QK8_1 % WARP_SIZE == 0);
+    const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
         stream->parallel_for(
             sycl::nd_range<3>(num_blocks * block_size, block_size),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-                quantize_q8_1(x, vy, kx, kx_padded, item_ct1);
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
             });
     }
 }
@@ -1854,7 +1534,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
 
         stream->parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                 mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
                                      nchannels_y, item_ct1);
             });
@@ -1874,7 +1554,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
 
         stream->parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                 mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
                                        row_stride_x, channel_stride_x,
                                        nchannels_y / nchannels_x, item_ct1);
@@ -2139,7 +1819,7 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
     const sycl::range<3> block_nums(1, nrows, 1);
     stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
                          [=](sycl::nd_item<3> item_ct1)
-                             [[intel::reqd_sub_group_size(32)]] {
+                             [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                                  k_sum_rows_f32(x, dst, ncols, item_ct1);
                              });
 }
@@ -2220,7 +1900,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
 
         cgh.parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                 soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
                                                                              nrows_y, scale, max_bias, m0,
                                                                              m1, n_head_log2, item_ct1,
@@ -2400,12 +2080,6 @@ static inline int get_sycl_env(const char *env_name, int default_val) {
     return user_number;
 }
 
-static inline int get_work_group_size(const sycl::device& device) {
-    dpct::device_info prop;
-    dpct::get_device_info(prop, device);
-    return prop.get_max_work_group_size();
-}
-
 static void ggml_check_sycl() try {
     static bool initialized = false;
 
@@ -2964,45 +2638,6 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor
     (void) src1_dd;
 }
 
-inline void ggml_sycl_op_norm(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
-
-    float eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
-
-    norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                    const ggml_tensor *src1, ggml_tensor *dst,
-                                    const float *src0_dd, const float *src1_dd,
-                                    float *dst_dd,
-                                    const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int num_groups = dst->op_params[0];
-    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
 inline void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                 const ggml_tensor *src1, ggml_tensor *dst,
                                 const float *src0_dd, const float *src1_dd,
@@ -3066,28 +2701,6 @@ inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor
     (void) src1_dd;
 }
 
-inline void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                  const float *src0_dd, const float *src1_dd,
-                                  float *dst_dd,
-                                  const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
-
-    float eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
-
-    rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
 static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split) {
     int64_t min_compute_capability = INT_MAX;
     int64_t max_compute_capability = INT_MIN;
@@ -4273,7 +3886,6 @@ bool ggml_sycl_supports_dmmv(enum ggml_type type) {
 
 static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
-
     int64_t min_compute_capability = INT_MAX;
 
     if (split) {