"SYCL0","CUMSUM","type=f32,ne=[375960,1,1,1]","support","0","no","SYCL"
"SYCL0","CUMSUM","type=f32,ne=[20481,4,1,1]","support","0","no","SYCL"
"SYCL0","XIELU","type=f32,ne=[10,5,4,3]","support","0","no","SYCL"
-"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","0","no","SYCL"
-"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","SYCL"
-"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","SYCL"
-"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","SYCL"
+"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","1","yes","SYCL"
+"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","1","yes","SYCL"
+"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","1","yes","SYCL"
+"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","1","yes","SYCL"
"SYCL0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","0","no","SYCL"
"SYCL0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","0","no","SYCL"
"SYCL0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","0","no","SYCL"
diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
}
+// Element-wise triangular masking kernel: copies src into dst, zeroing every
+// element outside the triangle selected by `ttype`. Tensors are assumed to be
+// contiguous f32 with identical shapes (the caller asserts this).
+static void tri_f32_sycl(
+    const float * src,
+    float * dst,
+    const int64_t ne0,
+    const int64_t ne1,
+    const int64_t ne2,
+    const int64_t ne3,
+    const ggml_tri_type ttype,
+    dpct::queue_ptr main_stream
+) {
+    const size_t n_elems = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3;
+
+    main_stream->parallel_for(sycl::range<1>(n_elems), [=](sycl::id<1> wi) {
+        const int64_t flat = (int64_t) wi[0];
+
+        // Recover column (i0) and row (i1) of this element; the higher
+        // dimensions (ne2/ne3 batches) do not influence the mask.
+        const int64_t col = flat % ne0;
+        const int64_t row = (flat / ne0) % ne1;
+
+        bool inside = false;
+        if (ttype == GGML_TRI_TYPE_LOWER) {
+            inside = col < row;   // strictly below the diagonal
+        } else if (ttype == GGML_TRI_TYPE_LOWER_DIAG) {
+            inside = col <= row;  // below or on the diagonal
+        } else if (ttype == GGML_TRI_TYPE_UPPER) {
+            inside = col > row;   // strictly above the diagonal
+        } else if (ttype == GGML_TRI_TYPE_UPPER_DIAG) {
+            inside = col >= row;  // above or on the diagonal
+        }
+
+        dst[flat] = inside ? src[flat] : 0.0f;
+    });
+}
+
+// Dispatch GGML_OP_TRI on the SYCL backend: validate the operands, then
+// launch the f32 triangular-masking kernel on the context's stream.
+static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    GGML_ASSERT(src0);
+
+    // Only contiguous f32 -> f32 with matching shapes is supported here.
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    dpct::queue_ptr stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    // The triangle selector is stored in the op's first i32 parameter.
+    const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
+
+    tri_f32_sycl(
+        static_cast<const float *>(src0->data),
+        static_cast<float *>(dst->data),
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        ttype,
+        stream);
+}
+
+
inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
case GGML_OP_TRANSPOSE:
GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
break;
+ case GGML_OP_TRI:
+ ggml_sycl_op_tri(ctx, dst);
+ break;
case GGML_OP_DIAG_MASK_INF:
ggml_sycl_diag_mask_inf(ctx, dst);
break;
return true;
case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
+ case GGML_OP_TRI:
+ {
+ const ggml_tensor * src0 = op->src[0];
+ return src0 &&
+ op->type == GGML_TYPE_F32 &&
+ ggml_is_contiguous(src0);
+ }
case GGML_OP_DIAG_MASK_INF:
return true;
case GGML_OP_SOFT_MAX: