]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sycl: implement GGML_OP_TRI (llama/19089)
authorRachelMantel <redacted>
Fri, 30 Jan 2026 04:00:49 +0000 (06:00 +0200)
committerGeorgi Gerganov <redacted>
Fri, 30 Jan 2026 11:49:29 +0000 (13:49 +0200)
* sycl: implement GGML_OP_TRI

* docs: update ops.md for SYCL TRI

* docs: regenerate ops.md

* docs: update SYCL support for GGML_OP_TRI

src/ggml-sycl/ggml-sycl.cpp

index 3a4c092af5d1858ed194653a29b6796b25a6660f..d20b7ec57df33a6f077f53b6ee95188ebb8dd5f5 100644 (file)
@@ -2263,6 +2263,65 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten
     diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
 }
 
+static void tri_f32_sycl(
+    const float * src,
+    float * dst,
+    const int64_t ne0,
+    const int64_t ne1,
+    const int64_t ne2,
+    const int64_t ne3,
+    const ggml_tri_type ttype,
+    dpct::queue_ptr main_stream
+) {
+    const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3;
+
+    main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) {
+        const int64_t idx = (int64_t) tid[0];
+
+        const int64_t i0 = idx % ne0;
+        const int64_t t1 = idx / ne0;
+        const int64_t i1 = t1 % ne1;
+
+        bool keep = false;
+        switch (ttype) {
+            case GGML_TRI_TYPE_LOWER:      keep = (i0 <  i1); break;
+            case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break;
+            case GGML_TRI_TYPE_UPPER:      keep = (i0 >  i1); break;
+            case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break;
+            default: keep = false; break;
+        }
+
+        dst[idx] = keep ? src[idx] : 0.0f;
+    });
+}
+
+static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    GGML_ASSERT(src0);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    const float * src0_dd = static_cast<const float *>(src0->data);
+    float *       dst_dd  = static_cast<float *>(dst->data);
+
+    const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
+
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
+    const int64_t ne3 = src0->ne[3];
+
+    tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream);
+}
+
+
 inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -3912,6 +3971,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_TRANSPOSE:
             GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
             break;
+        case GGML_OP_TRI:
+            ggml_sycl_op_tri(ctx, dst);
+            break;
         case GGML_OP_DIAG_MASK_INF:
             ggml_sycl_diag_mask_inf(ctx, dst);
             break;
@@ -4616,6 +4678,13 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
+        case GGML_OP_TRI:
+            {
+                const ggml_tensor * src0 = op->src[0];
+                return src0 &&
+                       op->type == GGML_TYPE_F32 &&
+                       ggml_is_contiguous(src0);
+            }
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX: