]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Fix im2col with 32fp (#5286)
authorAidanBeltonS <redacted>
Sat, 3 Feb 2024 08:11:37 +0000 (08:11 +0000)
committerGitHub <redacted>
Sat, 3 Feb 2024 08:11:37 +0000 (16:11 +0800)
ggml-sycl.cpp

index ac75f8e16f63e1cc13ac5dbda98f766da93ebf9b..51445b5e7eaaf5cf40a5b1e32e99339bd4f3ff9e 100644 (file)
@@ -8247,7 +8247,8 @@ static void clamp_f32(const float * x, float * dst, const float min, const float
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
 
-static void im2col_f32_f16(const float *x, sycl::half *dst, int offset_delta,
+template <typename T>
+static void im2col_kernel(const float *x, T *dst, int offset_delta,
                            int IW, int IH, int OW, int KW, int KH,
                            int pelements, int CHW, int s0, int s1, int p0,
                            int p1, int d0, int d1,
@@ -11019,7 +11020,8 @@ static void soft_max_f32_sycl(const float *x, const float *y, float *dst,
     });
 }
 
-static void im2col_f32_f16_sycl(const float *x, sycl::half *dst, int IW, int IH,
+template <typename T>
+static void im2col_sycl(const float *x, T *dst, int IW, int IH,
                                 int OW, int OH, int KW, int KH, int IC,
                                 int offset_delta, int s0, int s1, int p0,
                                 int p1, int d0, int d1,
@@ -11036,7 +11038,7 @@ static void im2col_f32_f16_sycl(const float *x, sycl::half *dst, int IW, int IH,
                                   sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
                               sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
             [=](sycl::nd_item<3> item_ct1) {
-                im2col_f32_f16(x, dst, offset_delta, IW, IH, OW, KW, KH,
+                im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH,
                                parallel_elements, (IC * KH * KW), s0, s1, p0,
                                p1, d0, d1, item_ct1);
             });
@@ -12424,7 +12426,7 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0,
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
 
     const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
     const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
@@ -12447,8 +12449,11 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0,
 
     const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
 
-    im2col_f32_f16_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH,
-                        IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    if (dst->type == GGML_TYPE_F16) {
+        im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    } else {
+        im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    }
 
     (void) src0;
     (void) src0_dd;