"WebGPU: WebGPU","CUMSUM","type=f32,ne=[375960,1,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","CUMSUM","type=f32,ne=[20481,4,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","XIELU","type=f32,ne=[10,5,4,3]","support","1","yes","WebGPU"
-"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","0","no","WebGPU"
-"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","WebGPU"
-"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","WebGPU"
-"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","WebGPU"
+"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","1","yes","WebGPU"
+"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","1","yes","WebGPU"
+"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","1","yes","WebGPU"
+"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","1","yes","WebGPU"
"WebGPU: WebGPU","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","WebGPU"
"WebGPU: WebGPU","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","WebGPU"
"WebGPU: WebGPU","FILL","type=f32,ne=[2048,512,2,2],c=3.500000","support","1","yes","WebGPU"
-"WebGPU: WebGPU","DIAG","type=f32,ne=[10,1,4,3]","support","0","no","WebGPU"
-"WebGPU: WebGPU","DIAG","type=f32,ne=[79,1,19,13]","support","0","no","WebGPU"
-"WebGPU: WebGPU","DIAG","type=f32,ne=[256,1,8,16]","support","0","no","WebGPU"
+"WebGPU: WebGPU","DIAG","type=f32,ne=[10,1,4,3]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","DIAG","type=f32,ne=[79,1,19,13]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","DIAG","type=f32,ne=[256,1,8,16]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","WebGPU"
/** Unary **/
struct ggml_webgpu_unary_pipeline_key {
- int type;
- int op;
- bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
- bool inplace;
+ int type;
+ int op;
+ bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
+ bool inplace;
+ ggml_tri_type ttype; // only used for GGML_OP_TRI
bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
- return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
+ return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace &&
+ ttype == other.ttype;
}
};
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.is_unary);
ggml_webgpu_hash_combine(seed, key.inplace);
+ ggml_webgpu_hash_combine(seed, key.ttype);
return seed;
}
};
.op = op,
.is_unary = is_unary,
.inplace = context.inplace,
+ .ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0),
};
auto it = unary_pipelines.find(key);
variant += "_inplace";
}
+ if (op == GGML_OP_TRI) {
+ switch (key.ttype) {
+ case GGML_TRI_TYPE_LOWER:
+ defines.push_back("TRI_TYPE_LOWER");
+ variant += "_tri_type_lower";
+ break;
+ case GGML_TRI_TYPE_LOWER_DIAG:
+ defines.push_back("TRI_TYPE_LOWER_DIAG");
+ variant += "_tri_type_lower_diag";
+ break;
+ case GGML_TRI_TYPE_UPPER:
+ defines.push_back("TRI_TYPE_UPPER");
+ variant += "_tri_type_upper";
+ break;
+ case GGML_TRI_TYPE_UPPER_DIAG:
+ defines.push_back("TRI_TYPE_UPPER_DIAG");
+ variant += "_tri_upper_diag";
+ break;
+ default:
+ GGML_ABORT("Unsupported ggml_tri_type for unary shader");
+ }
+ }
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_unary, defines);