]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sycl: implement GGML_OP_TRI (#19089)
author RachelMantel <redacted>
Fri, 30 Jan 2026 04:00:49 +0000 (06:00 +0200)
committer GitHub <redacted>
Fri, 30 Jan 2026 04:00:49 +0000 (12:00 +0800)
* sycl: implement GGML_OP_TRI

* docs: update ops.md for SYCL TRI

* docs: regenerate ops.md

* docs: update SYCL support for GGML_OP_TRI

docs/ops.md
docs/ops/SYCL.csv
ggml/src/ggml-sycl/ggml-sycl.cpp

index c066ab5a858a2928e7da8f386e00c3396e8548b3..b8e034780307099eb28ccc8c862a1c45464bc0c5 100644 (file)
@@ -114,7 +114,7 @@ Legend:
 |                             TANH | โŒ | โœ… | โœ… | ๐ŸŸก | ๐ŸŸก | โœ… | โœ… | ๐ŸŸก | โœ… | โŒ | โŒ |
 |               TIMESTEP_EMBEDDING | โŒ | โœ… | โœ… | โœ… | โœ… | โœ… | โœ… | โœ… | โŒ | โŒ | โŒ |
 |                            TOP_K | โŒ | โŒ | โœ… | โŒ | โœ… | โŒ | โŒ | ๐ŸŸก | โœ… | โŒ | โŒ |
-|                              TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
+|                              TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                            TRUNC | โŒ | โŒ | โœ… | ๐ŸŸก | โŒ | โŒ | ๐ŸŸก | ๐ŸŸก | โœ… | โŒ | โŒ |
 |                          UPSCALE | โŒ | ๐ŸŸก | โœ… | โœ… | ๐ŸŸก | ๐ŸŸก | ๐ŸŸก | ๐ŸŸก | โŒ | โŒ | โŒ |
 |                            XIELU | โŒ | โŒ | โœ… | โŒ | โŒ | โŒ | โŒ | โŒ | โœ… | โŒ | โŒ |
index 091a5caed7a03fc1494301fda0315454f74fc970..255b4ef4739efd5308f12afeaed4321fc65c4c43 100644 (file)
 "SYCL0","CUMSUM","type=f32,ne=[375960,1,1,1]","support","0","no","SYCL"
 "SYCL0","CUMSUM","type=f32,ne=[20481,4,1,1]","support","0","no","SYCL"
 "SYCL0","XIELU","type=f32,ne=[10,5,4,3]","support","0","no","SYCL"
-"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","0","no","SYCL"
-"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","SYCL"
-"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","SYCL"
-"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","SYCL"
+"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","1","yes","SYCL"
+"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","1","yes","SYCL"
+"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","1","yes","SYCL"
+"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","1","yes","SYCL"
 "SYCL0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","0","no","SYCL"
 "SYCL0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","0","no","SYCL"
 "SYCL0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","0","no","SYCL"
index 3a4c092af5d1858ed194653a29b6796b25a6660f..d20b7ec57df33a6f077f53b6ee95188ebb8dd5f5 100644 (file)
@@ -2263,6 +2263,65 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten
     diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
 }
 
+// Launch a SYCL kernel that copies the selected triangular part of a
+// contiguous f32 tensor into dst, writing 0.0f everywhere else.
+// The mask is taken from the two innermost dims: i0 = column, i1 = row;
+// the outer dims (ne2, ne3) only replicate the per-matrix mask.
+static void tri_f32_sycl(
+    const float * src,
+    float * dst,
+    const int64_t ne0,
+    const int64_t ne1,
+    const int64_t ne2,
+    const int64_t ne3,
+    const ggml_tri_type ttype,
+    dpct::queue_ptr main_stream
+) {
+    // One work-item per element over all four dims (tensors are contiguous,
+    // so a single flat index addresses both src and dst).
+    const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3;
+
+    main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) {
+        const int64_t idx = (int64_t) tid[0];
+
+        // Recover column (i0) and row (i1) from the flat index; i2/i3 are
+        // irrelevant to the mask and never materialized.
+        const int64_t i0 = idx % ne0;
+        const int64_t t1 = idx / ne0;
+        const int64_t i1 = t1 % ne1;
+
+        // Strict variants (LOWER/UPPER) exclude the main diagonal,
+        // the *_DIAG variants include it.
+        bool keep = false;
+        switch (ttype) {
+            case GGML_TRI_TYPE_LOWER:      keep = (i0 <  i1); break;
+            case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break;
+            case GGML_TRI_TYPE_UPPER:      keep = (i0 >  i1); break;
+            case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break;
+            default: keep = false; break;
+        }
+
+        // NOTE(review): masked-out elements are hard-coded to 0.0f here;
+        // confirm GGML_OP_TRI carries no configurable fill value in
+        // op_params that this kernel should honor instead.
+        dst[idx] = keep ? src[idx] : 0.0f;
+    });
+}
+
+// Forward implementation of GGML_OP_TRI for the SYCL backend: keeps the
+// triangular region of src0 selected by the op's tri_type param and zeroes
+// the rest into dst. Only contiguous f32 tensors of identical shape are
+// supported (enforced by the asserts below and by supports_op).
+static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    GGML_ASSERT(src0);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    // Contiguity lets the kernel address both tensors with one flat index.
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    const float * src0_dd = static_cast<const float *>(src0->data);
+    float *       dst_dd  = static_cast<float *>(dst->data);
+
+    // Triangle variant (upper/lower, with/without diagonal) is stored as
+    // the first i32 op param of the dst tensor.
+    const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
+
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
+    const int64_t ne3 = src0->ne[3];
+
+    tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream);
+}
+
+
 inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -3912,6 +3971,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_TRANSPOSE:
             GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
             break;
+        case GGML_OP_TRI:
+            ggml_sycl_op_tri(ctx, dst);
+            break;
         case GGML_OP_DIAG_MASK_INF:
             ggml_sycl_diag_mask_inf(ctx, dst);
             break;
@@ -4616,6 +4678,13 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
+        case GGML_OP_TRI:
+            {
+                const ggml_tensor * src0 = op->src[0];
+                return src0 &&
+                       op->type == GGML_TYPE_F32 &&
+                       ggml_is_contiguous(src0);
+            }
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX: