| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
-| PAD_REFLECT_1D | â\9d\8c | â\9c\85 | â\9c\85 | â\9d\8c | â\9c\85 | â\9d\8c | â\9d\8c | ❌ | ❌ |
+| PAD_REFLECT_1D | â\9d\8c | â\9c\85 | â\9c\85 | â\9d\8c | â\9c\85 | â\9d\8c | â\9c\85 | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
"SYCL0","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]","support","1","yes","SYCL"
"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1","support","1","yes","SYCL"
"SYCL0","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0","support","1","yes","SYCL"
-"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","no","SYCL"
-"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","no","SYCL"
+"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","yes","SYCL"
+"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","yes","SYCL"
"SYCL0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","0","no","SYCL"
"SYCL0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","0","no","SYCL"
"SYCL0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","SYCL"
#include "softmax.hpp"
#include "tsembd.hpp"
#include "wkv.hpp"
+#include "pad_reflect_1d.hpp"
+
#endif // GGML_SYCL_BACKEND_HPP
case GGML_OP_CONCAT:
ggml_sycl_op_concat(ctx, dst);
break;
+ case GGML_OP_PAD_REFLECT_1D:
+ ggml_sycl_op_pad_reflect_1d(ctx,dst);
+ break;
case GGML_OP_UPSCALE:
ggml_sycl_upscale(ctx, dst);
break;
case GGML_OP_DIV:
case GGML_OP_REPEAT:
return true;
+ case GGML_OP_PAD_REFLECT_1D:
+ return ggml_is_contiguous(op->src[0]) && op-> type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
--- /dev/null
+#include "pad_reflect_1d.hpp"
+
+void pad_reflect_1d_f32(const float* src,float* dst,
+ const int64_t ne0, const int64_t ne02, const int p0, const int p1,
+ const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+ const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+ const sycl::nd_item<3> &item_ct1){
+
+ const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0);
+ const int i1 = item_ct1.get_group(1);
+ const int g2 = item_ct1.get_group(2);
+ const int i2 = g2 % ne02;
+ const int i3 = g2 / ne02;
+
+ if (i0 >= p0 + ne0 + p1) return;
+
+ int t = i0 - p0;
+ int period = 2 * ne0 -2;
+ int m = t % period;
+ m += (m < 0) * period;
+ int center = ne0 -1;
+ int srci0 = center - abs(center - m);
+
+ int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0;
+ int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00;
+ dst[offest_dst] = src[offest_src];
+
+}
+
+void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst){
+
+ const ggml_tensor * src0 = dst->src[0];
+ queue_ptr stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int32_t * opts = (const int32_t *) dst->op_params;
+ const int p0 = opts[0];
+ const int p1 = opts[1];
+
+ const int64_t ne0 = src0->ne[0];
+
+ const int64_t ne00 = dst->ne[0];
+ const int64_t ne01 = dst->ne[1];
+ const int64_t ne02 = dst->ne[2];
+ const int64_t ne03 = dst->ne[3];
+
+ const int64_t nb00 = dst->nb[0];
+ const int64_t nb01 = dst->nb[1];
+ const int64_t nb02 = dst->nb[2];
+ const int64_t nb03 = dst->nb[3];
+ const int64_t nb0 = src0->nb[0];
+ const int64_t nb1 = src0->nb[1];
+ const int64_t nb2 = src0->nb[2];
+ const int64_t nb3 = src0->nb[3];
+
+ int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
+ sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03);
+ sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1);
+
+ stream->parallel_for(
+ sycl::nd_range<3>(global,
+ local),
+ [=](sycl::nd_item<3> item_ct1) { pad_reflect_1d_f32(
+ (const float *) src0->data, (float *) dst->data,
+ ne0, ne02, p0, p1,
+ nb0, nb1, nb2, nb3,
+ nb00, nb01, nb02, nb03
+ , item_ct1);
+ });
+}
--- /dev/null
+#ifndef GGML_SYCL_PAD_REFLECT_1D_HPP
+#define GGML_SYCL_PAD_REFLECT_1D_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
+
+#endif // GGML_SYCL_PAD_REFLECT_1D_HPP