From: Neo Zhang Jianyu
Date: Sun, 12 Oct 2025 13:53:35 +0000 (+0800)
Subject: fix UT fault cases: count-equal, argsort, pad OPs (llama/16521)
X-Git-Tag: upstream/1.8.2~20
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=be778c992fdb6e3e37b8f48feb84350f07dd4209;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp

fix UT fault cases: count-equal, argsort, pad OPs (llama/16521)

* fix/refactor OP argsort, pad

* fix count-equal op

* update SYCL OP list

* fix format issue

---------

Co-authored-by: Zhang Jianyu
---

diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp
index 410a67b0..6ff3215d 100644
--- a/ggml/src/ggml-sycl/backend.hpp
+++ b/ggml/src/ggml-sycl/backend.hpp
@@ -18,6 +18,7 @@
 #include "concat.hpp"
 #include "conv.hpp"
 #include "convert.hpp"
+#include "count-equal.hpp"
 #include "cpy.hpp"
 #include "dequantize.hpp"
 #include "dmmv.hpp"
@@ -28,6 +29,7 @@
 #include "mmvq.hpp"
 #include "norm.hpp"
 #include "outprod.hpp"
+#include "pad.hpp"
 #include "quantize.hpp"
 #include "quants.hpp"
 #include "rope.hpp"
diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp
index e0a1de0f..0a3883ae 100644
--- a/ggml/src/ggml-sycl/binbcast.cpp
+++ b/ggml/src/ggml-sycl/binbcast.cpp
@@ -303,10 +303,6 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
     ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
 }
 
-inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
-}
-
 inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
     ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
@@ -332,11 +328,6 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_op_sub(ctx, dst);
 }
 
-void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
-    ggml_sycl_op_count_equal(ctx, dst);
-}
-
 void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
     ggml_sycl_op_mul(ctx, dst);
diff --git a/ggml/src/ggml-sycl/binbcast.hpp b/ggml/src/ggml-sycl/binbcast.hpp
index 34c4064f..9cce0f05 100644
--- a/ggml/src/ggml-sycl/binbcast.hpp
+++ b/ggml/src/ggml-sycl/binbcast.hpp
@@ -16,12 +16,6 @@ static __dpct_inline__ float op_sub(const float a, const float b) {
     return a - b;
 }
 
-static __dpct_inline__ float op_count_equal(const float a, const float b) {
-    return (a == b) ? 1.0f : 0.0f;
-}
-
-void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
-
 static __dpct_inline__ float op_mul(const float a, const float b) {
     return a * b;
 }
diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp
index d66d7ade..338fa08c 100644
--- a/ggml/src/ggml-sycl/common.hpp
+++ b/ggml/src/ggml-sycl/common.hpp
@@ -195,7 +195,8 @@ struct optimize_feature {
 struct sycl_device_info {
     int cc; // compute capability
-    // int nsm; // number of streaming multiprocessors
+    int nsm; // number of streaming multiprocessors (CUDA); maps to the maximum
+             // number of compute units on a SYCL device.
     // size_t smpb; // max. shared memory per block
     size_t smpbo; // max. shared memory per block (with opt-in)
     bool vmm; // virtual memory support
diff --git a/ggml/src/ggml-sycl/count-equal.cpp b/ggml/src/ggml-sycl/count-equal.cpp
new file mode 100644
index 00000000..b0a8b482
--- /dev/null
+++ b/ggml/src/ggml-sycl/count-equal.cpp
@@ -0,0 +1,79 @@
+#include "count-equal.hpp"
+
+#include <cstdint>
+
+template <typename T>
+static void count_equal(const T *__restrict__ x, const T *__restrict__ y,
+                        int64_t *__restrict__ dst, const int64_t dk,
+                        const int64_t k) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int64_t i0 = (int64_t)item_ct1.get_group(2) * dk;
+    const int64_t i1 = sycl::min(i0 + dk, k);
+
+    int nequal = 0;
+
+    for (int64_t i = i0 + item_ct1.get_local_id(2); i < i1; i += WARP_SIZE) {
+        const T xi = x[i];
+        const T yi = y[i];
+        nequal += xi == yi;
+    }
+
+    nequal = warp_reduce_sum(nequal);
+
+    if (item_ct1.get_local_id(2) != 0) {
+        return;
+    }
+
+    dpct::atomic_fetch_add(
+        (int *)dst, nequal);
+}
+
+void ggml_sycl_count_equal(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == src1->type);
+    GGML_ASSERT( dst->type == GGML_TYPE_I64);
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    int64_t * dst_d = (int64_t *) dst->data;
+
+    dpct::queue_ptr stream = ctx.stream();
+    const int id = get_current_device_id();
+    const int nsm = ggml_sycl_info().devices[id].nsm;
+
+    const int64_t ne = ggml_nelements(src0);
+    GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
+    const int64_t dne =
+        GGML_PAD((ne + 4 * nsm - 1) / (4 * nsm), SYCL_COUNT_EQUAL_CHUNK_SIZE);
+
+    SYCL_CHECK(CHECK_TRY_ERROR(stream->memset(dst_d, 0, ggml_nbytes(dst))));
+
+    const dpct::dim3 block_dims(WARP_SIZE, 1, 1);
+    const dpct::dim3 block_nums(
+        std::min((int64_t)4 * nsm, (ne + SYCL_COUNT_EQUAL_CHUNK_SIZE - 1) /
+                                       SYCL_COUNT_EQUAL_CHUNK_SIZE),
+        1, 1);
+
+    switch (src0->type) {
+    case GGML_TYPE_I32: {
+        const int *src0_d = (const int *)src0->data;
+        const int *src1_d = (const int *)src1->data;
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                count_equal(src0_d, src1_d, dst_d, dne, ne);
+                GGML_UNUSED(item_ct1);
+            });
+
+    } break;
+    default:
+        GGML_ASSERT(false);
+        break;
+    }
+}
diff --git a/ggml/src/ggml-sycl/count-equal.hpp b/ggml/src/ggml-sycl/count-equal.hpp
new file mode 100644
index 00000000..f7f4fcbd
--- /dev/null
+++ b/ggml/src/ggml-sycl/count-equal.hpp
@@ -0,0 +1,9 @@
+#ifndef GGML_SYCL_COUNT_EQUAL_HPP
+#define GGML_SYCL_COUNT_EQUAL_HPP
+#include "common.hpp"
+
+#define SYCL_COUNT_EQUAL_CHUNK_SIZE 128
+
+void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif //GGML_SYCL_COUNT_EQUAL_HPP
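For reference, GGML_OP_COUNT_EQUAL compares two same-shaped integer tensors element-wise and returns a single I64 scalar holding the number of matching positions. The kernel above splits the flat element range into chunks of dne elements per work-group (at most 4*nsm work-groups of WARP_SIZE work-items each), warp-reduces the per-thread counts, and accumulates the result with an atomic add. A minimal host-side reference of those semantics (illustrative only, not part of the patch):

    // Reference semantics for GGML_OP_COUNT_EQUAL on contiguous I32 data;
    // the SYCL kernel above must produce the same single int64_t result.
    #include <cstdint>
    #include <vector>

    static int64_t count_equal_ref(const std::vector<int32_t> & a,
                                   const std::vector<int32_t> & b) {
        int64_t nequal = 0;
        for (size_t i = 0; i < a.size(); ++i) {
            nequal += a[i] == b[i];  // same comparison the device kernel performs
        }
        return nequal;               // device side: warp_reduce_sum + atomic add into dst
    }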
diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp
index c2da2fb4..aeeb3875 100644
--- a/ggml/src/ggml-sycl/element_wise.cpp
+++ b/ggml/src/ggml-sycl/element_wise.cpp
@@ -328,26 +328,6 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
     dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
 }
 
-template <typename T>
-static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
-                const sycl::nd_item<3> &item_ct1) {
-    int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
-    if (nidx >= ne0) {
-        return;
-    }
-
-    // operation
-    int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
-                     item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
-    if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) {
-        int offset_src = nidx + item_ct1.get_group(1) * ne00 +
-                         item_ct1.get_group(0) * ne00 * ne01;
-        dst[offset_dst] = x[offset_src];
-    } else {
-        dst[offset_dst] = static_cast<T>(0.0f);
-    }
-}
-
 template <typename T>
 static void clamp(const T * x, T * dst, const float min, const float max, const int k,
                   const sycl::nd_item<1> &item_ct1) {
@@ -431,18 +411,6 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
         });
 }
 
-template <typename T>
-static void pad_sycl(const T *x, T *dst, const int ne00,
-                     const int ne01, const int ne02, const int ne0,
-                     const int ne1, const int ne2, queue_ptr stream) {
-    int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
-    sycl::range<3> gridDim(ne2, ne1, num_blocks);
-    stream->parallel_for(
-        sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
-}
-
 template <typename KernelInvoker, typename... Args>
 static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
 #if defined (GGML_SYCL_F16)
@@ -596,40 +564,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
     }
 }
 
-template <typename KernelInvoker, typename... Args>
-static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
-                               (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
-                               (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
-}
 
 } // namespace ggml_sycl_detail
@@ -919,14 +853,6 @@ static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_te
         });
 }
 
-static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
-           queue_ptr stream) {
-            ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
-        });
-}
-
 static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     float min_val;
     float max_val;
@@ -1119,10 +1045,6 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_op_upscale(ctx, dst);
 }
 
-void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
-    ggml_sycl_op_pad(ctx, dst);
-}
 
 void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp
index 50749e87..43474317 100644
--- a/ggml/src/ggml-sycl/element_wise.hpp
+++ b/ggml/src/ggml-sycl/element_wise.hpp
@@ -67,8 +67,6 @@ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
-
 void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index e4cc3c8e..45b8c216 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -85,9 +85,11 @@ static ggml_sycl_device_info ggml_sycl_init() {
         info.devices[i].cc = 100 * prop.get_major_version() + 10 * prop.get_minor_version();
+        info.devices[i].nsm = prop.get_max_compute_units();
         info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
-        info.max_work_group_sizes[i] = prop.get_max_work_group_size();
         info.devices[i].smpbo = prop.get_local_mem_size();
+
+        info.max_work_group_sizes[i] = prop.get_max_work_group_size();
     }
 
     for (int id = 0; id < info.device_count; ++id) {
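The new nsm field mirrors the CUDA backend's streaming-multiprocessor count; on SYCL it is simply the device's maximum number of compute units, which the count-equal launcher above uses to size its grid (at most 4*nsm work-groups). A standalone query of the same figure would look roughly like this (a sketch, not part of the patch):

    // Illustrative only: reads the value that ggml_sycl_init() stores in
    // info.devices[i].nsm via prop.get_max_compute_units().
    #include <sycl/sycl.hpp>

    int main() {
        sycl::queue q{sycl::default_selector_v};
        const auto cu =
            q.get_device().get_info<sycl::info::device::max_compute_units>();
        return cu > 0 ? 0 : 1;
    }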
@@ -1512,60 +1514,70 @@ static inline void ggml_sycl_swap(T & a, T & b) {
 
 template <ggml_sort_order order>
 __dpct_inline__ static void
 k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
-                  const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
+                  const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,
+                  uint8_t *dpct_local) {
     // bitonic sort
-    int col = item_ct1.get_local_id(2);
+    int col_index = item_ct1.get_local_id(2);
     int row = item_ct1.get_group(1);
 
-    if (col >= ncols_pad) {
-        return;
+    for (int i = 0; i < tasks_per_thread; i++) {
+        int col = col_index * tasks_per_thread + i;
+        if (col >= ncols_pad) {
+            return;
+        }
     }
 
     const float * x_row = x + row * ncols;
     auto dst_row = (int *)dpct_local;
 
     // initialize indices
-    dst_row[col] = col;
+    for (int i = 0; i < tasks_per_thread; i++) {
+        int col = col_index * tasks_per_thread + i;
+        dst_row[col] = col;
+    }
+    item_ct1.barrier(sycl::access::fence_space::local_space);
 
     for (int k = 2; k <= ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
-            int ixj = col ^ j;
-            if (ixj > col) {
-                if ((col & k) == 0) {
-                    if (dst_row[col] >= ncols ||
-                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
-                    ) {
-                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
-                    }
-                } else {
-                    if (dst_row[ixj] >= ncols ||
-                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
-                    ) {
-                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
+            for (int i = 0; i < tasks_per_thread; i++) {
+                int col = col_index * tasks_per_thread + i;
+                int ixj = col ^ j;
+                if (ixj > col) {
+                    if ((col & k) == 0) {
+                        if (dst_row[col] >= ncols ||
+                            (dst_row[ixj] < ncols &&
+                             (order == GGML_SORT_ORDER_ASC
+                                  ? x_row[dst_row[col]] > x_row[dst_row[ixj]]
+                                  : x_row[dst_row[col]] <
+                                        x_row[dst_row[ixj]]))) {
+                            ggml_sycl_swap(dst_row[col], dst_row[ixj]);
+                        }
+                    } else {
+                        if (dst_row[ixj] >= ncols ||
+                            (dst_row[col] < ncols &&
+                             (order == GGML_SORT_ORDER_ASC
+                                  ? x_row[dst_row[col]] < x_row[dst_row[ixj]]
+                                  : x_row[dst_row[col]] >
+                                        x_row[dst_row[ixj]]))) {
+                            ggml_sycl_swap(dst_row[col], dst_row[ixj]);
+                        }
                     }
                 }
+                item_ct1.barrier(sycl::access::fence_space::local_space);
             }
-            /*
-            DPCT1118:1: SYCL group functions and algorithms must be encountered
-            in converged control flow. You may need to adjust the code.
-            */
-            item_ct1.barrier(sycl::access::fence_space::local_space);
         }
     }
 
     // copy the result to dst without the padding
-    if (col < ncols) {
-        dst[row * ncols + col] = dst_row[col];
+    for (int i = 0; i < tasks_per_thread; i++) {
+        int col = col_index * tasks_per_thread + i;
+        if (col < ncols) {
+            dst[row * ncols + col] = dst_row[col];
+        }
     }
 }
 
-
 static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
                               const sycl::nd_item<3> &item_ct1) {
     const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
@@ -1738,11 +1750,20 @@ static int next_power_of_2(int x) {
 
 static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
                                  const int nrows, ggml_sort_order order,
-                                 queue_ptr stream) {
+                                 queue_ptr stream, int device) {
     // bitonic sort requires ncols to be power of 2
     const int ncols_pad = next_power_of_2(ncols);
-    const sycl::range<3> block_dims(1, 1, ncols_pad);
+    int nth = 1;
+    int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
+    while (nth < ncols_pad && nth < max_block_size)
+        nth *= 2;
+    if (nth > max_block_size)
+        nth = max_block_size;
+
+    const int tasks_per_thread = ncols_pad / nth;
+
+    const sycl::range<3> block_dims(1, 1, nth);
     const sycl::range<3> block_nums(1, nrows, 1);
     const size_t shared_mem = ncols_pad * sizeof(int);
@@ -1755,8 +1776,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1) {
                     k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
-                        x, dst, ncols, ncols_pad, item_ct1,
-                        dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
+                        x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
+                        dpct_local_acc_ct1
+                            .get_multi_ptr<sycl::access::decorated::no>()
                             .get());
                 });
         });
@@ -1769,8 +1791,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1) {
                     k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
-                        x, dst, ncols, ncols_pad, item_ct1,
-                        dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
+                        x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
+                        dpct_local_acc_ct1
+                            .get_multi_ptr<sycl::access::decorated::no>()
                             .get());
                 });
         });
@@ -2142,7 +2165,8 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
 
     enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
-    argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
+    argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,
+                         main_stream, ctx.device);
 }
 
 inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
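The substance of the argsort fix is the launch geometry: the work-group size nth is now capped at the device's maximum work-group size, and each work-item handles ncols_pad / nth columns of the bitonic network, so rows longer than the work-group limit no longer produce an invalid launch. The arithmetic in isolation, with example numbers (illustrative, mirrors argsort_f32_i32_sycl above):

    #include <cstdio>

    // Same next-power-of-two padding the kernel relies on
    // (bitonic sort needs a power-of-two column count).
    static int next_power_of_2_example(int x) {
        int n = 1;
        while (n < x) n *= 2;
        return n;
    }

    int main() {
        const int ncols          = 1500;  // row length to sort
        const int max_block_size = 1024;  // example device work-group limit
        const int ncols_pad      = next_power_of_2_example(ncols);  // 2048
        int nth = 1;
        while (nth < ncols_pad && nth < max_block_size) nth *= 2;   // 1024
        if (nth > max_block_size) nth = max_block_size;
        const int tasks_per_thread = ncols_pad / nth;               // 2
        std::printf("ncols_pad=%d nth=%d tasks_per_thread=%d\n",
                    ncols_pad, nth, tasks_per_thread);
        return 0;
    }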
@@ -4413,8 +4437,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ACC:
             return true;
         case GGML_OP_PAD:
-            return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
-                   (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_RWKV_WKV6:
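The previous support check rejected any non-zero left padding (op_params 0, 2, 4 and 6), because the old element_wise kernel could only append zeros on the right of each dimension. The new pad.cpp kernel below reads all eight paddings, so the backend only needs a contiguous source. At the graph level those eight values come from ggml_pad_ext; a usage sketch (illustrative, not part of the patch):

    // Produces the lp0/rp0 ... lp3/rp3 op_params consumed by ggml_sycl_op_pad.
    #include "ggml.h"

    struct ggml_tensor * pad_example(struct ggml_context * ctx, struct ggml_tensor * a) {
        // one element of left padding and two of right padding on dim 0,
        // three of right padding on dim 1, nothing on dims 2 and 3
        return ggml_pad_ext(ctx, a,
                            /*lp0=*/1, /*rp0=*/2,
                            /*lp1=*/0, /*rp1=*/3,
                            /*lp2=*/0, /*rp2=*/0,
                            /*lp3=*/0, /*rp3=*/0);
    }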
diff --git a/ggml/src/ggml-sycl/pad.cpp b/ggml/src/ggml-sycl/pad.cpp
new file mode 100644
index 00000000..413712c5
--- /dev/null
+++ b/ggml/src/ggml-sycl/pad.cpp
@@ -0,0 +1,97 @@
+//
+// MIT license
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+//#include "common.hpp"
+#include "pad.hpp"
+
+static void pad_f32(const float * src, float * dst,
+                    const int lp0, const int rp0, const int lp1, const int rp1,
+                    const int lp2, const int rp2, const int lp3, const int rp3,
+                    const int ne0, const int ne1, const int ne2, const int ne3) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    int i0 = item_ct1.get_local_id(2) +
+             item_ct1.get_group(2) * item_ct1.get_local_range(2);
+    int i1 = item_ct1.get_group(1);
+    int i2 = item_ct1.get_group(0) % ne2;
+    int i3 = item_ct1.get_group(0) / ne2;
+    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
+
+    // operation
+    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+    if ((i0 >= lp0 && i0 < ne0 - rp0) &&
+        (i1 >= lp1 && i1 < ne1 - rp1) &&
+        (i2 >= lp2 && i2 < ne2 - rp2) &&
+        (i3 >= lp3 && i3 < ne3 - rp3)) {
+        const int64_t i00 = i0 - lp0;
+        const int64_t i01 = i1 - lp1;
+        const int64_t i02 = i2 - lp2;
+        const int64_t i03 = i3 - lp3;
+        const int64_t ne02 = ne2 - lp2 - rp2;
+        const int64_t ne01 = ne1 - lp1 - rp1;
+        const int64_t ne00 = ne0 - lp0 - rp0;
+
+        const int64_t src_idx = i03 * (ne00 * ne01 * ne02) +
+                                i02 * (ne00 * ne01) + i01 * ne00 + i00;
+
+        dst[dst_idx] = src[src_idx];
+    } else {
+        dst[dst_idx] = 0.0f;
+    }
+}
+
+static void pad_f32_sycl(const float *src, float *dst, const int lp0,
+                         const int rp0, const int lp1, const int rp1,
+                         const int lp2, const int rp2, const int lp3,
+                         const int rp3, const int ne0, const int ne1,
+                         const int ne2, const int ne3,
+                         dpct::queue_ptr stream) {
+    int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
+    dpct::dim3 gridDim(num_blocks, ne1, ne2 * ne3);
+    stream->parallel_for(
+        sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1,
+                    ne2, ne3);
+        });
+}
+
+void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    dpct::queue_ptr stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
+    const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
+    const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
+
+    pad_f32_sycl(src0_d, dst_d,
+                 lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
+                 dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+}
+
+void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_pad(ctx, dst);
+}
diff --git a/ggml/src/ggml-sycl/pad.hpp b/ggml/src/ggml-sycl/pad.hpp
new file mode 100644
index 00000000..b099e9b7
--- /dev/null
+++ b/ggml/src/ggml-sycl/pad.hpp
@@ -0,0 +1,24 @@
+//
+// MIT license
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_PAD_HPP
+#define GGML_SYCL_PAD_HPP
+
+#include "common.hpp"
+
+#define SYCL_PAD_BLOCK_SIZE 256
+
+void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif // GGML_SYCL_PAD_HPP
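These three operators are the ones exercised by the failing unit tests referenced in the title. Assuming the usual llama.cpp/ggml build layout, the fixed paths can be re-checked with test-backend-ops, e.g. something like `./build/bin/test-backend-ops test -o PAD -b SYCL0` and the corresponding runs for ARGSORT and COUNT_EQUAL; the exact flag spelling may differ between versions, so treat the command as a sketch rather than a verbatim invocation.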