From: Masashi Yoshimura Date: Thu, 19 Feb 2026 16:18:30 +0000 (+0900) Subject: ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support. (llama/19700) X-Git-Tag: v0.9.8~117 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=85028134904e6fc03a5470c41d1c385cb31ad5b8;p=pkg%2Fggml%2Fsources%2Fggml ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support. (llama/19700) * ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support. * Fix to cast the src value to f32 before sin/cos computing. --- diff --git a/src/ggml-webgpu/ggml-webgpu.cpp b/src/ggml-webgpu/ggml-webgpu.cpp index b5fee480..1c00d3cb 100644 --- a/src/ggml-webgpu/ggml-webgpu.cpp +++ b/src/ggml-webgpu/ggml-webgpu.cpp @@ -2008,6 +2008,14 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_LOG: return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_SQR: + return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_SQRT: + return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_SIN: + return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_COS: + return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_PAD: return ggml_webgpu_pad(ctx, src0, node); case GGML_OP_ARGMAX: @@ -2967,6 +2975,18 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_LOG: supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; + case GGML_OP_SQR: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_SQRT: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_SIN: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_COS: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; case GGML_OP_PAD: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; diff --git a/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/src/ggml-webgpu/wgsl-shaders/unary.wgsl index d639d984..feaf6d0a 100644 --- a/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -170,6 +170,20 @@ fn main(@builtin(global_invocation_id) gid: vec3) { #ifdef TRUNC let res = trunc(src[params.offset_src + src_idx]); #endif +#ifdef SQR + let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx]; +#endif +#ifdef SQRT + let res = sqrt(src[params.offset_src + src_idx]); +#endif +#ifdef SIN + let res_f32 = sin(f32(src[params.offset_src + src_idx])); + let res = TYPE(res_f32); +#endif +#ifdef COS + let res_f32 = cos(f32(src[params.offset_src + src_idx])); + let res = TYPE(res_f32); +#endif #ifdef INPLACE src[params.offset_src + src_idx] = res;