]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-webgpu: Add supports for `DIAG` and `TRI` (#20664)
authorMasashi Yoshimura <redacted>
Thu, 19 Mar 2026 04:08:35 +0000 (13:08 +0900)
committerGitHub <redacted>
Thu, 19 Mar 2026 04:08:35 +0000 (21:08 -0700)
* Add supports for DIAG and TRI.

* Remove extra ttype and add a comment for TRI op.

docs/ops.md
docs/ops/WebGPU.csv
ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl

index 1357771442011521346fb7e7e8fb728db86a8fff..47534b1401c92f550df226fdd0e851c2ea9dc156 100644 (file)
@@ -37,7 +37,7 @@ Legend:
 |               CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 |          CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                           CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
-|                             DIAG | â\9d\8c | â\9d\8c | â\9c\85 | â\9c\85 | â\9d\8c | â\9d\8c | â\9d\8c | â\9c\85 | â\9d\8c | ❌ | ❌ |
+|                             DIAG | â\9d\8c | â\9d\8c | â\9c\85 | â\9c\85 | â\9d\8c | â\9d\8c | â\9d\8c | â\9c\85 | â\9c\85 | ❌ | ❌ |
 |                    DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                              DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
 |                              DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
@@ -115,7 +115,7 @@ Legend:
 |                             TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
 |               TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                            TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
-|                              TRI | â\9d\8c | â\9d\8c | â\9c\85 | â\9c\85 | â\9c\85 | â\9d\8c | â\9c\85 | â\9c\85 | â\9d\8c | ❌ | ❌ |
+|                              TRI | â\9d\8c | â\9d\8c | â\9c\85 | â\9c\85 | â\9c\85 | â\9d\8c | â\9c\85 | â\9c\85 | â\9c\85 | ❌ | ❌ |
 |                            TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
 |                          UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                            XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
index b7761b9dd3f91ac2f3d01a663d71dee373980838..56bae2f3c81d7bb30dc82eb46369a88c9e3e6f9c 100644 (file)
 "WebGPU: WebGPU","CUMSUM","type=f32,ne=[375960,1,1,1]","support","1","yes","WebGPU"
 "WebGPU: WebGPU","CUMSUM","type=f32,ne=[20481,4,1,1]","support","1","yes","WebGPU"
 "WebGPU: WebGPU","XIELU","type=f32,ne=[10,5,4,3]","support","1","yes","WebGPU"
-"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","0","no","WebGPU"
-"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","WebGPU"
-"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","WebGPU"
-"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","WebGPU"
+"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","1","yes","WebGPU"
+"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","1","yes","WebGPU"
+"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","1","yes","WebGPU"
+"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","1","yes","WebGPU"
 "WebGPU: WebGPU","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","WebGPU"
 "WebGPU: WebGPU","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","WebGPU"
 "WebGPU: WebGPU","FILL","type=f32,ne=[2048,512,2,2],c=3.500000","support","1","yes","WebGPU"
-"WebGPU: WebGPU","DIAG","type=f32,ne=[10,1,4,3]","support","0","no","WebGPU"
-"WebGPU: WebGPU","DIAG","type=f32,ne=[79,1,19,13]","support","0","no","WebGPU"
-"WebGPU: WebGPU","DIAG","type=f32,ne=[256,1,8,16]","support","0","no","WebGPU"
+"WebGPU: WebGPU","DIAG","type=f32,ne=[10,1,4,3]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","DIAG","type=f32,ne=[79,1,19,13]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","DIAG","type=f32,ne=[256,1,8,16]","support","1","yes","WebGPU"
 "WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","WebGPU"
 "WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","WebGPU"
 "WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","WebGPU"
index 3d7e59fddf322d5ccfee1367edc01f5fb8fb4382..ad665e4de93c157d41dbf5099afd9254b2978c5f 100644 (file)
@@ -244,13 +244,15 @@ struct ggml_webgpu_binary_pipeline_key_hash {
 /** Unary **/
 
 struct ggml_webgpu_unary_pipeline_key {
-    int  type;
-    int  op;
-    bool is_unary;  // many unary operators fall under the GGML_OP_UNARY umbrella
-    bool inplace;
+    int           type;
+    int           op;
+    bool          is_unary;  // many unary operators fall under the GGML_OP_UNARY umbrella
+    bool          inplace;
+    ggml_tri_type ttype;     // only used for GGML_OP_TRI
 
     bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
-        return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
+        return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace &&
+               ttype == other.ttype;
     }
 };
 
@@ -261,6 +263,7 @@ struct ggml_webgpu_unary_pipeline_key_hash {
         ggml_webgpu_hash_combine(seed, key.op);
         ggml_webgpu_hash_combine(seed, key.is_unary);
         ggml_webgpu_hash_combine(seed, key.inplace);
+        ggml_webgpu_hash_combine(seed, key.ttype);
         return seed;
     }
 };
@@ -1058,6 +1061,7 @@ class ggml_webgpu_shader_lib {
                  .op       = op,
                  .is_unary = is_unary,
                  .inplace  = context.inplace,
+                 .ttype    = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0),
         };
 
         auto it = unary_pipelines.find(key);
@@ -1088,6 +1092,29 @@ class ggml_webgpu_shader_lib {
             variant += "_inplace";
         }
 
+        if (op == GGML_OP_TRI) {
+            switch (key.ttype) {
+                case GGML_TRI_TYPE_LOWER:
+                    defines.push_back("TRI_TYPE_LOWER");
+                    variant += "_tri_type_lower";
+                    break;
+                case GGML_TRI_TYPE_LOWER_DIAG:
+                    defines.push_back("TRI_TYPE_LOWER_DIAG");
+                    variant += "_tri_type_lower_diag";
+                    break;
+                case GGML_TRI_TYPE_UPPER:
+                    defines.push_back("TRI_TYPE_UPPER");
+                    variant += "_tri_type_upper";
+                    break;
+                case GGML_TRI_TYPE_UPPER_DIAG:
+                    defines.push_back("TRI_TYPE_UPPER_DIAG");
+                    variant += "_tri_upper_diag";
+                    break;
+                default:
+                    GGML_ABORT("Unsupported ggml_tri_type for unary shader");
+            }
+        }
+
         defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
 
         auto processed           = preprocessor.preprocess(wgsl_unary, defines);
index 3976a171d166a7c7d863cbf7b5c91a0981d16aaa..4b0eeac0f42115981e5708eac751da557a0e89eb 100644 (file)
@@ -2209,6 +2209,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
         case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
+        case GGML_OP_DIAG:
+        case GGML_OP_TRI:
             return ggml_webgpu_unary_op(ctx, src0, node);
         case GGML_OP_PAD:
             return ggml_webgpu_pad(ctx, src0, node);
@@ -3201,6 +3203,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_COS:
             supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
             break;
+        case GGML_OP_DIAG:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+            break;
+        case GGML_OP_TRI:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+            break;
         case GGML_OP_PAD:
             supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
             break;
index feaf6d0ac292d3b32f68679b70befef4500c04fe..21beb9bb94d1e5303ffe6a08986e9a68797ead57 100644 (file)
@@ -5,7 +5,6 @@ enable f16;
 #define TYPE f32
 #endif
 
-
 @group(0) @binding(0)
 var<storage, read_write> src: array<TYPE>;
 
@@ -57,12 +56,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
       return;
     }
     var i = gid.x;
-    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
-    i = i % (params.ne2 * params.ne1 * params.ne0);
-    let i2 = i / (params.ne1 * params.ne0);
-    i = i % (params.ne1 * params.ne0);
-    let i1 = i / params.ne0;
-    let i0 = i % params.ne0;
+    let ne2 = params.ne2;
+#ifdef DIAG
+    let ne1 = params.ne0;
+#else
+    let ne1 = params.ne1;
+#endif
+    let ne0 = params.ne0;
+
+    let i3 = i / (ne2 * ne1 * ne0);
+    i = i % (ne2 * ne1 * ne0);
+    let i2 = i / (ne1 * ne0);
+    i = i % (ne1 * ne0);
+    let i1 = i / ne0;
+    let i0 = i % ne0;
 
     let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
                   i2 * params.stride_src2 + i3 * params.stride_src3;
@@ -184,6 +191,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     let res_f32 = cos(f32(src[params.offset_src + src_idx]));
     let res = TYPE(res_f32);
 #endif
+#ifdef DIAG
+    let res = select(0.0, src[params.offset_src + i0 + i2 * params.stride_src2 + i3 * params.stride_src3], i0 == i1);
+#endif
+#ifdef TRI
+#ifdef TRI_TYPE_LOWER
+    let res = select(0.0, src[params.offset_src + src_idx], i0 < i1);
+#elif TRI_TYPE_LOWER_DIAG
+    let res = select(0.0, src[params.offset_src + src_idx], i0 <= i1);
+#elif TRI_TYPE_UPPER
+    let res = select(0.0, src[params.offset_src + src_idx], i0 > i1);
+#elif TRI_TYPE_UPPER_DIAG
+    let res = select(0.0, src[params.offset_src + src_idx], i0 >= i1);
+#endif
+#endif
 
 #ifdef INPLACE
     src[params.offset_src + src_idx] = res;