"WebGPU: WebGPU","ARGMAX","type=f32,ne=[1024,12,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","ARGMAX","type=f32,ne=[2000,10,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","ARGMAX","type=f32,ne=[5438,3,1,1]","support","1","yes","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[2,1,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,2,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,2,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,2]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,1],nr=[2,1,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,1],nr=[1,1,1,2]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[2,1,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,2,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,2,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,2]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,3],nr=[2,1,1,1]","support","0","no","WebGPU"
-"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,3],nr=[1,1,1,2]","support","0","no","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[2,1,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,2,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,2,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,2]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,1],nr=[2,1,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,1],nr=[1,1,1,2]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[2,1,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,2,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,2,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,2]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,3],nr=[2,1,1,1]","support","1","yes","WebGPU"
+"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,3],nr=[1,1,1,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[1,1,1,1],v=0","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[2,1,1,1],v=0","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[1,2,1,1],v=0","support","0","no","WebGPU"
}
};
+/** Repeat **/
+
+struct ggml_webgpu_repeat_pipeline_key {
+    int type; // ggml type of the destination tensor — the only axis the repeat shader is specialized on
+
+    bool operator==(const ggml_webgpu_repeat_pipeline_key & other) const { return type == other.type; }
+};
+
+struct ggml_webgpu_repeat_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const { // hash functor for the repeat_pipelines unordered_map
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type); // hashes the single key field; mirrors operator==
+        return seed;
+    }
+};
+
/** Binary **/
struct ggml_webgpu_binary_pipeline_key {
binary_pipelines; // type/op/inplace/overlap
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
concat_pipelines; // type
+ std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
+ repeat_pipelines; // type
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
}
std::vector<std::string> defines;
- std::string variant = "concat";
+ std::string variant = "concat";
switch (key.type) {
case GGML_TYPE_F32:
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
- auto processed = preprocessor.preprocess(wgsl_concat, defines);
- auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
- decisions->wg_size = context.max_wg_size;
+ auto processed = preprocessor.preprocess(wgsl_concat, defines);
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+ decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
- pipeline.context = decisions;
- concat_pipelines[key] = pipeline;
+ pipeline.context = decisions;
+ concat_pipelines[key] = pipeline;
return concat_pipelines[key];
}
+    webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) { // fetch-or-build the compute pipeline for GGML_OP_REPEAT, keyed on dst type
+        ggml_webgpu_repeat_pipeline_key key = { // cache key: specialized on destination element type only
+            .type = context.dst->type,
+        };
+
+        auto it = repeat_pipelines.find(key); // fast path: reuse an already-compiled pipeline
+        if (it != repeat_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines; // WGSL preprocessor defines (element type + workgroup size)
+        std::string variant = "repeat"; // pipeline label; suffixed with the element type below
+
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                defines.push_back("TYPE_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_I32:
+                defines.push_back("TYPE_I32");
+                variant += "_i32";
+                break;
+            case GGML_TYPE_I16:
+                defines.push_back("TYPE_I16");
+                variant += "_i16";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for repeat shader"); // must stay in sync with the GGML_OP_REPEAT case in supports_op
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); // size workgroups to the device maximum
+
+        auto processed = preprocessor.preprocess(wgsl_repeat, defines); // expand the defines into the WGSL source
+        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size = context.max_wg_size; // recorded so dispatch can compute workgroup counts later
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context = decisions;
+        repeat_pipelines[key] = pipeline; // cache for subsequent lookups
+        return repeat_pipelines[key];
+    }
+
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
const bool has_mask = context.src3 != nullptr;
const bool has_sinks = context.src4 != nullptr;
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
+static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) { // record a GGML_OP_REPEAT dispatch that tiles src0 into dst
+    uint32_t ne = (uint32_t) ggml_nelements(dst); // one shader invocation per destination element
+
+    std::vector<uint32_t> params = { ne,
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) /
+                                                 ggml_type_size(src0->type)), // src offset in elements past the aligned binding start
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // dst offset in elements
+                                     (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), // src strides converted from bytes to elements, dims 0..3
+                                     (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+                                     (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+                                     (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+                                     (uint32_t) (src0->ne[0]), // src extents, dims 0..3
+                                     (uint32_t) (src0->ne[1]),
+                                     (uint32_t) (src0->ne[2]),
+                                     (uint32_t) (src0->ne[3]),
+                                     (uint32_t) (dst->ne[0]), // dst extents, dims 0..2; ne[3] is not passed — presumably derived in the shader from ne, verify against wgsl_repeat
+                                     (uint32_t) (dst->ne[1]),
+                                     (uint32_t) (dst->ne[2]) };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0, // source tensor buffer
+          .buffer = ggml_webgpu_tensor_buf(src0),
+          .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+          .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
+        { .binding = 1, // destination tensor buffer
+          .buffer = ggml_webgpu_tensor_buf(dst),
+          .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+          .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
+    };
+
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0 = src0,
+        .dst = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+    };
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); // cached per destination type
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); // enough workgroups to cover every dst element
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
int inplace = ggml_webgpu_tensor_equal(src, dst);
return ggml_webgpu_binary_op(ctx, src0, src1, node);
case GGML_OP_CONCAT:
return ggml_webgpu_concat(ctx, src0, src1, node);
+ case GGML_OP_REPEAT:
+ return ggml_webgpu_repeat(ctx, src0, node);
case GGML_OP_RMS_NORM:
return ggml_webgpu_rms_norm(ctx, src0, node);
case GGML_OP_ROPE:
/* .iface = */ {
/* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
/* .alloc_buffer = */
- ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
- ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
- ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
- ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
+ ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
+ ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
+ ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
+ ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
},
/* .device = */
dev,
case GGML_OP_CONCAT:
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
break;
+ case GGML_OP_REPEAT:
+ supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32 || src0->type == GGML_TYPE_I16);
+ break;
case GGML_OP_CPY:
case GGML_OP_CONT:
supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&