]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support. (llama/19700)
authorMasashi Yoshimura <redacted>
Thu, 19 Feb 2026 16:18:30 +0000 (01:18 +0900)
committerGeorgi Gerganov <redacted>
Wed, 25 Feb 2026 10:32:13 +0000 (12:32 +0200)
* ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support.

* Fix to cast the src value to f32 before sin/cos computing.

src/ggml-webgpu/ggml-webgpu.cpp
src/ggml-webgpu/wgsl-shaders/unary.wgsl

index b5fee480562bd1bb3fdfdde10dc39d70eb362db3..1c00d3cb2b16c9cff0da95c169c622d53f8b8a84 100644 (file)
@@ -2008,6 +2008,14 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
             return ggml_webgpu_unary_op(ctx, src0, node);
         case GGML_OP_LOG:
             return ggml_webgpu_unary_op(ctx, src0, node);
+        case GGML_OP_SQR:
+            return ggml_webgpu_unary_op(ctx, src0, node);
+        case GGML_OP_SQRT:
+            return ggml_webgpu_unary_op(ctx, src0, node);
+        case GGML_OP_SIN:
+            return ggml_webgpu_unary_op(ctx, src0, node);
+        case GGML_OP_COS:
+            return ggml_webgpu_unary_op(ctx, src0, node);
         case GGML_OP_PAD:
             return ggml_webgpu_pad(ctx, src0, node);
         case GGML_OP_ARGMAX:
@@ -2967,6 +2975,18 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_LOG:
             supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
             break;
+        case GGML_OP_SQR:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+            break;
+        case GGML_OP_SQRT:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+            break;
+        case GGML_OP_SIN:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+            break;
+        case GGML_OP_COS:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+            break;
         case GGML_OP_PAD:
             supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
             break;
index d639d984970a04ec967b3e7325c22ec0cf1286c8..feaf6d0ac292d3b32f68679b70befef4500c04fe 100644 (file)
@@ -170,6 +170,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 #ifdef TRUNC
     let res = trunc(src[params.offset_src + src_idx]);
 #endif
+#ifdef SQR
+    let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx];
+#endif
+#ifdef SQRT
+    let res = sqrt(src[params.offset_src + src_idx]);
+#endif
+#ifdef SIN
+    let res_f32 = sin(f32(src[params.offset_src + src_idx]));
+    let res = TYPE(res_f32);
+#endif
+#ifdef COS
+    let res_f32 = cos(f32(src[params.offset_src + src_idx]));
+    let res = TYPE(res_f32);
+#endif
 
 #ifdef INPLACE
     src[params.offset_src + src_idx] = res;