/** Unary **/
struct ggml_webgpu_unary_pipeline_key {
- int type;
- int op;
- bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
- bool inplace;
+ int type;
+ int op;
+ bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
+ bool inplace;
+ ggml_tri_type ttype; // only used for GGML_OP_TRI
bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
- return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
+ return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace &&
+ ttype == other.ttype;
}
};
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.is_unary);
ggml_webgpu_hash_combine(seed, key.inplace);
+ ggml_webgpu_hash_combine(seed, key.ttype);
return seed;
}
};
.op = op,
.is_unary = is_unary,
.inplace = context.inplace,
+ .ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0),
};
auto it = unary_pipelines.find(key);
variant += "_inplace";
}
+ if (op == GGML_OP_TRI) {
+ switch (key.ttype) {
+ case GGML_TRI_TYPE_LOWER:
+ defines.push_back("TRI_TYPE_LOWER");
+ variant += "_tri_type_lower";
+ break;
+ case GGML_TRI_TYPE_LOWER_DIAG:
+ defines.push_back("TRI_TYPE_LOWER_DIAG");
+ variant += "_tri_type_lower_diag";
+ break;
+ case GGML_TRI_TYPE_UPPER:
+ defines.push_back("TRI_TYPE_UPPER");
+ variant += "_tri_type_upper";
+ break;
+ case GGML_TRI_TYPE_UPPER_DIAG:
+ defines.push_back("TRI_TYPE_UPPER_DIAG");
+ variant += "_tri_upper_diag";
+ break;
+ default:
+ GGML_ABORT("Unsupported ggml_tri_type for unary shader");
+ }
+ }
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_unary, defines);
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
+ case GGML_OP_DIAG:
+ case GGML_OP_TRI:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_PAD:
return ggml_webgpu_pad(ctx, src0, node);
case GGML_OP_COS:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
+ case GGML_OP_DIAG:
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+ break;
+ case GGML_OP_TRI:
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+ break;
case GGML_OP_PAD:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
#define TYPE f32
#endif
-
@group(0) @binding(0)
var<storage, read_write> src: array<TYPE>;
return;
}
var i = gid.x;
- let i3 = i / (params.ne2 * params.ne1 * params.ne0);
- i = i % (params.ne2 * params.ne1 * params.ne0);
- let i2 = i / (params.ne1 * params.ne0);
- i = i % (params.ne1 * params.ne0);
- let i1 = i / params.ne0;
- let i0 = i % params.ne0;
+ let ne2 = params.ne2;
+#ifdef DIAG
+ let ne1 = params.ne0;
+#else
+ let ne1 = params.ne1;
+#endif
+ let ne0 = params.ne0;
+
+ let i3 = i / (ne2 * ne1 * ne0);
+ i = i % (ne2 * ne1 * ne0);
+ let i2 = i / (ne1 * ne0);
+ i = i % (ne1 * ne0);
+ let i1 = i / ne0;
+ let i0 = i % ne0;
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3;
let res_f32 = cos(f32(src[params.offset_src + src_idx]));
let res = TYPE(res_f32);
#endif
+#ifdef DIAG
+ let res = select(0.0, src[params.offset_src + i0 + i2 * params.stride_src2 + i3 * params.stride_src3], i0 == i1);
+#endif
+#ifdef TRI
+#ifdef TRI_TYPE_LOWER
+ let res = select(0.0, src[params.offset_src + src_idx], i0 < i1);
+#elif TRI_TYPE_LOWER_DIAG
+ let res = select(0.0, src[params.offset_src + src_idx], i0 <= i1);
+#elif TRI_TYPE_UPPER
+ let res = select(0.0, src[params.offset_src + src_idx], i0 > i1);
+#elif TRI_TYPE_UPPER_DIAG
+ let res = select(0.0, src[params.offset_src + src_idx], i0 >= i1);
+#endif
+#endif
#ifdef INPLACE
src[params.offset_src + src_idx] = res;