]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Fix SYCL `im2col` and `convert` Overflow with Large Dims (llama/9052)
authorzhentaoyu <redacted>
Tue, 20 Aug 2024 15:06:51 +0000 (23:06 +0800)
committerGeorgi Gerganov <redacted>
Tue, 27 Aug 2024 19:01:14 +0000 (22:01 +0300)
* sycl: fix im2col overflow and sync with cuda

Signed-off-by: zhentaoyu <redacted>
* sycl: fix convert overflow

Signed-off-by: zhentaoyu <redacted>
* sycl: fix convert and dequantize

Signed-off-by: zhentaoyu <redacted>
* sycl: fix ib in dmmv

Signed-off-by: zhentaoyu <redacted>
* sycl:refine convert

Signed-off-by: zhentaoyu <redacted>
* sycl: move downsample global_range into common

Signed-off-by: zhentaoyu <redacted>
* test: add im2col and convert test cases

Signed-off-by: zhentaoyu <redacted>
* test: make new cases only in sycl

Signed-off-by: zhentaoyu <redacted>
* test: comment new test_cases for only local testing

Signed-off-by: zhentaoyu <redacted>
---------

Signed-off-by: zhentaoyu <redacted>
src/ggml-sycl.cpp
src/ggml-sycl/backend.hpp
src/ggml-sycl/common.cpp
src/ggml-sycl/common.hpp
src/ggml-sycl/convert.cpp
src/ggml-sycl/convert.hpp
src/ggml-sycl/dequantize.hpp
src/ggml-sycl/dmmv.cpp
src/ggml-sycl/im2col.cpp [new file with mode: 0644]
src/ggml-sycl/im2col.hpp [new file with mode: 0644]
tests/test-backend-ops.cpp

index d8eb86c2c1862b45c914f8c213a4e12bcb2348c2..165b5a6df36f6946efa6c34e8f4627bd27acffbe 100644 (file)
@@ -893,43 +893,6 @@ static void clamp_f32(const float * x, float * dst, const float min, const float
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
 
-template <typename T>
-static void im2col_kernel(const float *x, T *dst, int offset_delta,
-                           int IW, int IH, int OW, int KW, int KH,
-                           int pelements, int CHW, int s0, int s1, int p0,
-                           int p1, int d0, int d1,
-                           const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_id(2) +
-                  item_ct1.get_group(2) * item_ct1.get_local_range(2);
-    if (i >= pelements) {
-        return;
-    }
-
-    const int ksize = OW * (KH > 1 ? KW : 1);
-    const int kx = i / ksize;
-    const int kd = kx * ksize;
-    const int ky = (i - kd) / OW;
-    const int ix = i % OW;
-
-    const int64_t iiw = ix * s0 + kx * d0 - p0;
-    const int64_t iih = item_ct1.get_group(1) * s1 + ky * d1 - p1;
-
-    const int64_t offset_dst =
-        (item_ct1.get_group(1) * OW + ix) * CHW +
-        (item_ct1.get_group(0) * (KW * KH) + ky * KW + kx);
-
-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-        dst[offset_dst] =
-            sycl::vec<float, 1>(0.0f)
-                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-    } else {
-        const int64_t offset_src = item_ct1.get_group(0) * offset_delta;
-        dst[offset_dst] =
-            sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
-                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-    }
-}
-
 template <typename Ti, typename To>
 static  void pool2d_nchw_kernel(
         const int ih, const int iw, const int oh, const int ow,
@@ -1742,32 +1705,6 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
                          });
 }
 
-template <typename T>
-static void im2col_sycl(const float *x, T *dst, int IW, int IH,
-                                int OW, int OH, int KW, int KH, int IC,
-                                int offset_delta, int s0, int s1, int p0,
-                                int p1, int d0, int d1,
-                                queue_ptr stream) {
-    const int parallel_elements = OW * KW * KH;
-    const int num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
-    sycl::range<3> block_nums(IC, OH, num_blocks);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums *
-                                  sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH,
-                               parallel_elements, (IC * KH * KW), s0, s1, p0,
-                               p1, d0, d1, item_ct1);
-            });
-    }
-}
-
-
 static bool g_sycl_loaded = false;
 
 bool ggml_sycl_loaded(void) {
@@ -2636,47 +2573,6 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
     (void) src1_dd;
 }
 
-inline void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst,
-                                const float *src0_dd, const float *src1_dd,
-                                float *dst_dd,
-                                const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
-
-    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
-
-    const int64_t IC = src1->ne[is_2D ? 2 : 1];
-    const int64_t IH = is_2D ? src1->ne[1] : 1;
-    const int64_t IW =         src1->ne[0];
-
-    const int64_t KH = is_2D ? src0->ne[1] : 1;
-    const int64_t KW =         src0->ne[0];
-
-    const int64_t OH = is_2D ? dst->ne[2] : 1;
-    const int64_t OW =         dst->ne[1];
-
-    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
-
-    if (dst->type == GGML_TYPE_F16) {
-        im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
-    } else {
-        im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
-    }
-
-    (void) src0;
-    (void) src0_dd;
-}
-
 inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                   const ggml_tensor *src1, ggml_tensor *dst,
                                   const float *src0_dd, const float *src1_dd,
index 58dd9c9a60e7d72339d16d9dd6bf120bd69e27f7..d21b5f8dd2627537b9fb9013d1e87695f2646673 100644 (file)
@@ -25,5 +25,6 @@
 #include "norm.hpp"
 #include "softmax.hpp"
 #include "tsembd.hpp"
+#include "im2col.hpp"
 
 #endif // GGML_SYCL_BACKEND_HPP
index e878f4f50f09e2a4a26c18de858cd8db8303d88d..cf5291b31fe9172029e533f5381a0d9f601e1ee2 100644 (file)
@@ -51,3 +51,14 @@ void ggml_sycl_host_free(void* ptr) try {
             << ", line:" << __LINE__ << std::endl;
   std::exit(1);
 }
+
+int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
+  const int64_t max_range = std::numeric_limits<int>::max();
+  int64_t sycl_down_blk_size = block_size;
+  int64_t global_range = accumulate_block_num * sycl_down_blk_size;
+  while(global_range > max_range) {
+      sycl_down_blk_size /= 2;
+      global_range = accumulate_block_num * sycl_down_blk_size;
+  }
+  return sycl_down_blk_size;
+}
index 86d8b40e8b01333a7a24bed2e8b950cd5fab5c60..1dbf437090606585e00960f140de9914abeb6efe 100644 (file)
@@ -352,4 +352,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
     return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
 }
 
+int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
+
 #endif // GGML_SYCL_COMMON_HPP
index 39c28753cf85c29467b2514cd7f63617e2906a19..5fd15e6cdccabb277eff3b3b3171cbce976712c8 100644 (file)
@@ -3,19 +3,19 @@
 #include "presets.hpp"
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
+static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
                              const sycl::nd_item<3> &item_ct1) {
-    const int i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+    const int64_t i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                        item_ct1.get_local_id(2));
 
     if (i >= k) {
         return;
     }
 
-    const int ib = i/qk; // block index
-    const int iqs = (i%qk)/qr; // quant index
-    const int iybs = i - i%qk; // y block start index
-    const int y_offset = qr == 1 ? 1 : qk/2;
+    const int64_t ib = i/qk; // block index
+    const int64_t iqs = (i%qk)/qr; // quant index
+    const int64_t iybs = i - i%qk; // y block start index
+    const int64_t y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
     dfloat2 v;
@@ -27,9 +27,9 @@ static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_sycl(const void *__restrict__ vx,
-                                  dst_t *__restrict__ y, const int k,
+                                  dst_t *__restrict__ y, const int64_t k,
                                   dpct::queue_ptr stream) {
-    const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
+    const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
@@ -45,9 +45,9 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
 }
 
 template <typename dst_t>
-static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 #if QK_K == 256
     {
         dpct::has_capability_or_fail(stream->get_device(),
@@ -77,9 +77,9 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 #if QK_K == 256
     {
         dpct::has_capability_or_fail(stream->get_device(),
@@ -108,10 +108,10 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
-    const int nb32 = k / 32;
-    const int nb = (k + 255) / 256;
+    const int64_t nb32 = k / 32;
+    const int64_t nb = (k + 255) / 256;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
@@ -126,10 +126,10 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
-    const int nb32 = k / 32;
-    const int nb = (k + 255) / 256;
+    const int64_t nb32 = k / 32;
+    const int64_t nb = (k + 255) / 256;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
@@ -145,9 +145,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
 
 
 template <typename dst_t>
-static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
@@ -165,9 +165,9 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 #if QK_K == 256
     {
         dpct::has_capability_or_fail(stream->get_device(),
@@ -197,9 +197,9 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 #if QK_K == 256
     {
         dpct::has_capability_or_fail(stream->get_device(),
@@ -229,9 +229,9 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
                                         dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
@@ -250,9 +250,9 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
                                         dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
@@ -271,9 +271,9 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
                                         dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
@@ -292,9 +292,9 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k,
                                        dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
@@ -313,9 +313,9 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
                                       dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
@@ -333,9 +333,9 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
 
 
 template <typename dst_t>
-static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
                                         dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
@@ -354,9 +354,9 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
                                         dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
@@ -374,9 +374,9 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k,
                                        dpct::queue_ptr stream) {
-    const int nb = (k + QK_K - 1) / QK_K;
+    const int64_t nb = (k + QK_K - 1) / QK_K;
 #if QK_K == 64
     dequantize_row_iq4_nl_sycl(vx, y, k, stream);
 #else
@@ -398,9 +398,9 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k,
                                        dpct::queue_ptr stream) {
-    const int nb = (k + QK_K - 1) / QK_K;
+    const int64_t nb = (k + QK_K - 1) / QK_K;
       {
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});
@@ -418,34 +418,34 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename src_t, typename dst_t>
-static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
+static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
                           const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
+    const int64_t work_group_size = item_ct1.get_local_range(2);
+    const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
 
+    // make each work-item deal with more elements since sycl global range can not exceed max int
     const src_t * x = (src_t *) vx;
-
-    y[i] = x[i];
+    for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
+        y[i] = x[i];
+    }
 }
 
 template <typename src_t, typename dst_t>
 static void convert_unary_sycl(const void *__restrict__ vx,
-                               dst_t *__restrict__ y, const int k,
+                               dst_t *__restrict__ y, const int64_t k,
                                dpct::queue_ptr stream) {
-    const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
+    const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
+
+    // decrease global range when it exceeds the max int
+    int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE);
+    sycl::range<3> block_nums(1, 1, num_blocks);
+    sycl::range<3> local_range(1, 1, local_size);
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
         stream->parallel_for(
-            sycl::nd_range<3>(
-                sycl::range<3>(1, 1, num_blocks) *
-                    sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
-                sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
+            sycl::nd_range<3>(block_nums * local_range, local_range),
             [=](sycl::nd_item<3> item_ct1) {
                 convert_unary<src_t>(vx, y, k, item_ct1);
             });
index b1f10d635553540e9d01e709ea7dac7bca9704a2..0ce2874aaaef9cb519211b0343e54fcb7dd06ec6 100644 (file)
@@ -17,7 +17,7 @@
 
 template <typename T>
 using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
-                             int k, dpct::queue_ptr stream);
+                             int64_t k, dpct::queue_ptr stream);
 typedef to_t_sycl_t<float> to_fp32_sycl_t;
 typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
 
index ed8ad098bcb2fdec53c2bf79a3fbaa64a1311ee4..8f4041fffce33564a354fd1195880f570f5276d1 100644 (file)
@@ -15,9 +15,9 @@
 
 #include "common.hpp"
 
-typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
+typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
 
-static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
+static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
@@ -40,7 +40,7 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
 #endif // GGML_SYCL_F16
 }
 
-static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
+static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
@@ -64,7 +64,7 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
 #endif // GGML_SYCL_F16
 }
 
-static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
+static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q5_0 * x = (const block_q5_0 *) vx;
 
@@ -91,7 +91,7 @@ static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
 #endif // GGML_SYCL_F16
 }
 
-static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
+static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
@@ -118,7 +118,7 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
 #endif // GGML_SYCL_F16
 }
 
-static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
+static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q8_0 * x = (const block_q8_0 *) vx;
 
@@ -138,16 +138,16 @@ static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
 }
 
 template<typename dst_t>
-static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
+static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
                                   const sycl::nd_item<3> &item_ct1) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
 
     // assume 32 threads
-    const int tid = item_ct1.get_local_id(2);
-    const int il  = tid/8;
-    const int ir  = tid%8;
-    const int ib = 8*i + ir;
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
+    const int64_t ib = 8*i + ir;
     if (ib >= nb32) {
         return;
     }
@@ -168,16 +168,16 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri
 }
 
 template<typename dst_t>
-static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
+static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
                                   const sycl::nd_item<3> &item_ct1) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
 
     // assume 32 threads
-    const int tid = item_ct1.get_local_id(2);
-    const int il  = tid/8;
-    const int ir  = tid%8;
-    const int ib = 8*i + ir;
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
+    const int64_t ib = 8*i + ir;
     if (ib >= nb32) {
         return;
     }
@@ -203,14 +203,14 @@ template<typename dst_t>
 static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   const sycl::nd_item<3> &item_ct1) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_q2_K * x = (const block_q2_K *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int n   = tid/32;
-    const int l   = tid - 32*n;
-    const int is  = 8*n + l/16;
+    const int64_t n   = tid/32;
+    const int64_t l   = tid - 32*n;
+    const int64_t is  = 8*n + l/16;
 
     const uint8_t q = x[i].qs[32*n + l];
     dst_t * y = yy + i*QK_K + 128*n;
@@ -222,8 +222,8 @@ static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restri
     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
     y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
 #else
-    const int is = tid/16;  // 0 or 1
-    const int il = tid%16;  // 0...15
+    const int64_t is = tid/16;  // 0 or 1
+    const int64_t il = tid%16;  // 0...15
     const uint8_t q = x[i].qs[il] >> (2*is);
     dst_t * y = yy + i*QK_K + 16*is + il;
 
@@ -239,19 +239,19 @@ template<typename dst_t>
 static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   const sycl::nd_item<3> &item_ct1) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_q3_K * x = (const block_q3_K *) vx;
 
 #if QK_K == 256
-    const int r = item_ct1.get_local_id(2) / 4;
-    const int tid = r/2;
-    const int is0 = r%2;
-    const int l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
-    const int n = tid / 4;
-    const int j = tid - 4*n;
+    const int64_t r = item_ct1.get_local_id(2) / 4;
+    const int64_t tid = r/2;
+    const int64_t is0 = r%2;
+    const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
+    const int64_t n = tid / 4;
+    const int64_t j = tid - 4*n;
 
     uint8_t m = 1 << (4*n + j);
-    int is = 8*n + 2*j + is0;
+    int64_t is = 8*n + 2*j + is0;
     int shift = 2*j;
 
     int8_t us = is <  4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
@@ -267,11 +267,11 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
 
     for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
 #else
-    const int tid = item_ct1.get_local_id(2);
-    const int is  = tid/16;  // 0 or 1
-    const int il  = tid%16;  // 0...15
-    const int im  = il/8;    // 0...1
-    const int in  = il%8;    // 0...7
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t is  = tid/16;  // 0 or 1
+    const int64_t il  = tid%16;  // 0...15
+    const int64_t im  = il/8;    // 0...1
+    const int64_t in  = il%8;    // 0...7
 
     dst_t * y = yy + i*QK_K + 16*is + il;
 
@@ -307,15 +307,15 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
                                   uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
 
 #if QK_K == 256
     // assume 32 threads
-    const int tid = item_ct1.get_local_id(2);
-    const int il  = tid/8;
-    const int ir  = tid%8;
-    const int is  = 2*il;
-    const int n   = 4;
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
+    const int64_t is  = 2*il;
+    const int64_t n   = 4;
 
     dst_t * y = yy + i*QK_K + 64*il + n*ir;
 
@@ -341,7 +341,7 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
         y[l +32] = d2 * (q_vec[l] >>  4) - m2;
     }
 #else
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
     const uint8_t * q = x[i].qs;
     dst_t * y = yy + i*QK_K;
     const float d = (float)x[i].dm[0];
@@ -356,14 +356,14 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
                                   const sycl::nd_item<3> &item_ct1) {
     const block_q5_K * x = (const block_q5_K *) vx;
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
 
 #if QK_K == 256
     // assume 64 threads - this is very slightly better than the one below
-    const int tid = item_ct1.get_local_id(2);
-    const int il  = tid/16;   // il is in 0...3
-    const int ir  = tid%16;   // ir is in 0...15
-    const int is  = 2*il;     // is is in 0...6
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il  = tid/16;   // il is in 0...3
+    const int64_t ir  = tid%16;   // ir is in 0...15
+    const int64_t is  = 2*il;     // is is in 0...6
 
     dst_t * y = yy + i*QK_K + 64*il + 2*ir;
 
@@ -386,11 +386,11 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
     y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;
     y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
 #else
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
     const uint8_t q = x[i].qs[tid];
-    const int im = tid/8;  // 0...3
-    const int in = tid%8;  // 0...7
-    const int is = tid/16; // 0 or 1
+    const int64_t im = tid/8;  // 0...3
+    const int64_t in = tid%8;  // 0...7
+    const int64_t is = tid/16; // 0 or 1
     const uint8_t h = x[i].qh[in] >> im;
     const float d = x[i].d;
     dst_t * y = yy + i*QK_K + tid;
@@ -404,14 +404,14 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
                                   const sycl::nd_item<3> &item_ct1) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
 #if QK_K == 256
 
     // assume 64 threads - this is very slightly better than the one below
-    const int tid = item_ct1.get_local_id(2);
-    const int ip  = tid/32;   // ip is 0 or 1
-    const int il  = tid - 32*ip; // 0...32
-    const int is  = 8*ip + il/16;
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t ip  = tid/32;   // ip is 0 or 1
+    const int64_t il  = tid - 32*ip; // 0...32
+    const int64_t is  = 8*ip + il/16;
 
     dst_t * y = yy + i*QK_K + 128*ip + il;
 
@@ -428,9 +428,9 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
 #else
 
     // assume 32 threads
-    const int tid = item_ct1.get_local_id(2);
-    const int ip  = tid/16;         // 0 or 1
-    const int il  = tid - 16*ip;    // 0...15
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t ip  = tid/16;         // 0 or 1
+    const int64_t il  = tid - 16*ip;    // 0...15
 
     dst_t * y = yy + i*QK_K + 16*ip + il;
 
@@ -452,13 +452,13 @@ static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __res
                                      const uint8_t *ksigns_iq2xs_ptr,
                                      const uint8_t *kmask_iq2xs_ptr) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq2_xxs * x = (const block_iq2_xxs  *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint16_t * q2 = x[i].qs + 4*ib;
     const uint8_t  * aux8 = (const uint8_t *)q2;
@@ -480,13 +480,13 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __rest
                                     const uint8_t *ksigns_iq2xs,
                                     const uint8_t *kmask_iq2xs) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq2_xs * x = (const block_iq2_xs *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint16_t * q2 = x[i].qs + 4*ib;
     const uint8_t  * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
@@ -504,13 +504,13 @@ __dpct_inline__ static void
 dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
                        const sycl::nd_item<3> &item_ct1) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq2_s * x = (const block_iq2_s *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
     const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
@@ -532,13 +532,13 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res
                                      const uint8_t *ksigns_iq2xs,
                                      const uint8_t *kmask_iq2xs) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq3_xxs * x = (const block_iq3_xxs  *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint8_t  * q3 = x[i].qs + 8*ib;
     const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
@@ -563,13 +563,13 @@ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
                        const sycl::nd_item<3> &item_ct1,
                        const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq3_s * x = (const block_iq3_s *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint8_t * qs = x[i].qs + 8*ib;
     const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
@@ -593,13 +593,13 @@ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
                        const sycl::nd_item<3> &item_ct1,
                        const uint32_t *iq1s_grid_gpu) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq1_s * x = (const block_iq1_s  *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
     const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
@@ -623,13 +623,13 @@ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
                        const sycl::nd_item<3> &item_ct1,
                        const uint32_t *iq1s_grid_gpu) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq1_m * x = (const block_iq1_m  *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint16_t * sc = (const uint16_t *)x[i].scales;
     iq1m_scale_t scale;
@@ -656,12 +656,12 @@ __dpct_inline__ static void
 dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
                         const sycl::nd_item<3> &item_ct1) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
 
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 4*il;
     const uint8_t  * q4 = x[ib].qs + 4*il;
     const float d = (float)x[ib].d;
@@ -678,12 +678,12 @@ template <typename dst_t>
 __dpct_inline__ static void
 dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
                         const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq4_xs * x = (const block_iq4_xs *)vx;
 
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 4*il;
     const uint8_t  * q4 = x[i].qs + 16*ib + 4*il;
     const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
index ae45630e1173d3fcd4a8ac2c87ed22187d8e766c..5c343822f390f7fcb5e334ee36be21a33bfe2995 100644 (file)
@@ -4,7 +4,7 @@
 #include "presets.hpp"
 
 
-static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const sycl::half *x = (const sycl::half *)vx;
 
     // automatic half -> float type cast if dfloat == float
@@ -12,7 +12,7 @@ static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 &
     v.y() = x[ib + iqs + 1];
 }
 
-static void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const float * x = (const float *) vx;
 
     // automatic half -> float type cast if dfloat == float
diff --git a/src/ggml-sycl/im2col.cpp b/src/ggml-sycl/im2col.cpp
new file mode 100644 (file)
index 0000000..6a0a0fc
--- /dev/null
@@ -0,0 +1,125 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#include "im2col.hpp"
+
+template <typename T>
+static void im2col_kernel(
+        const float *x, T *dst, int64_t batch_offset, int64_t offset_delta,
+        int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
+        int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1,
+        const sycl::nd_item<3> &item_ct1) {
+    const int64_t work_group_size = item_ct1.get_local_range(2);
+    const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
+
+    // make each work-item deal with more elements since sycl global range can not exceed max int
+    for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) {
+
+        const int64_t ksize = OW * (KH > 1 ? KW : 1);
+        const int64_t kx = i / ksize;
+        const int64_t kd = kx * ksize;
+        const int64_t ky = (i - kd) / OW;
+        const int64_t ix = i % OW;
+
+        const int64_t  oh = item_ct1.get_group(1);
+        const int64_t  batch = item_ct1.get_group(0) / IC;
+        const int64_t  ic = item_ct1.get_group(0) % IC;
+
+        const int64_t iiw = ix * s0 + kx * d0 - p0;
+        const int64_t iih = oh * s1 + ky * d1 - p1;
+
+        const int64_t offset_dst =
+            ((batch * OH + oh) * OW + ix) * CHW +
+            (ic * (KW * KH) + ky * KW + kx);
+
+        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+            dst[offset_dst] =
+                sycl::vec<float, 1>(0.0f)
+                    .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+        } else {
+            const int64_t offset_src = ic * offset_delta + batch * batch_offset;
+            dst[offset_dst] =
+                sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
+                    .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+        }
+    }
+}
+
+template <typename T>
+static void im2col_sycl(
+        const float *x, T *dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW,
+        int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta,
+        int s0, int s1, int p0, int p1, int d0, int d1,
+        queue_ptr stream) {
+    const int64_t parallel_elements = OW * KW * KH;
+    const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
+
+    // decrease global range when it exceeds the max int
+    int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE);
+    sycl::range<3> block_nums(batch * IC, OH, num_blocks);
+    sycl::range<3> local_range(1, 1, local_size);
+
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * local_range, local_range),
+            [=](sycl::nd_item<3> item_ct1) {
+                im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH,
+                               parallel_elements, (IC * KH * KW), s0, s1, p0,
+                               p1, d0, d1, item_ct1);
+            });
+    }
+}
+
+void ggml_sycl_op_im2col(
+        ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+        ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
+        const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW =         src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW =         src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW =         dst->ne[1];
+
+    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t batch = src1->ne[3];
+    const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
+
+    if (dst->type == GGML_TYPE_F16) {
+        im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    } else {
+        im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    }
+
+    (void) src0;
+    (void) src0_dd;
+}
diff --git a/src/ggml-sycl/im2col.hpp b/src/ggml-sycl/im2col.hpp
new file mode 100644 (file)
index 0000000..7db144f
--- /dev/null
@@ -0,0 +1,23 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_IM2COL_HPP
+#define GGML_SYCL_IM2COL_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_im2col(
+        ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+        ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
+        const queue_ptr &main_stream);
+
+#endif // GGML_SYCL_IM2COL_HPP
index 8f60854b2d34e6bf7fc2e10a48e485f33c8cd1cc..95037e77c180acf6acbc751242540b5f7ccd8a66 100644 (file)
@@ -2218,6 +2218,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
 
+    // sycl backend will limit task global_range < MAX_INT
+    // test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
+    // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.)
+    // these cases are verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend)
+    // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
+    // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
+
     test_cases.emplace_back(new test_conv_transpose_1d());
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));
@@ -2360,6 +2367,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 45, 128, { 8,  1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45,  64, { 8,  1}, {4, 1}));
 
+    // sycl backend will limit task global_range < MAX_INT
+    // test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion)
+    // however this case needs to alloc more memory which may fail in some devices (Intel Arc770, etc.)
+    // this case is verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend)
+    // test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 512, 262144, 9216, {1, 1}, {1, 1}));
+
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
             for (int n_mats : {4, 8}) {