From: Abhijit Ramesh Date: Wed, 1 Apr 2026 09:58:53 +0000 (+0300) Subject: ggml-webgpu: port all AOT operators to JIT (llama/20728) X-Git-Tag: v0.9.10~9 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=175e9fbdc9fae8aada32084522808760048f9f6c;p=pkg%2Fggml%2Fsources%2Fggml ggml-webgpu: port all AOT operators to JIT (llama/20728) * port cpy pipeline to shader lib with JIT compilation * port glu pipeline to shader lib with JIT compilation * port rope pipeline to shader lib with JIT compilation * port soft_max pipeline to shader lib with JIT compilation * removed unused functions from embed_wgsl.py which were used for old AOT template expansion --- diff --git a/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 59861ac1..97863f40 100644 --- a/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -535,6 +535,95 @@ struct ggml_webgpu_mul_mat_shader_decisions { uint32_t mul_mat_wg_size; }; +/** Cpy **/ + +struct ggml_webgpu_cpy_pipeline_key { + ggml_type src_type; + ggml_type dst_type; + + bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const { + return src_type == other.src_type && dst_type == other.dst_type; + } +}; + +struct ggml_webgpu_cpy_pipeline_key_hash { + size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src_type); + ggml_webgpu_hash_combine(seed, key.dst_type); + return seed; + } +}; + +/** Glu **/ + +struct ggml_webgpu_glu_pipeline_key { + ggml_glu_op glu_op; + ggml_type type; + bool split; + + bool operator==(const ggml_webgpu_glu_pipeline_key & other) const { + return glu_op == other.glu_op && type == other.type && split == other.split; + } +}; + +struct ggml_webgpu_glu_pipeline_key_hash { + size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.glu_op); + ggml_webgpu_hash_combine(seed, key.type); + 
ggml_webgpu_hash_combine(seed, key.split); + return seed; + } +}; + +/** Rope **/ + +struct ggml_webgpu_rope_pipeline_key { + ggml_type type; + bool inplace; + bool has_ff; + + bool operator==(const ggml_webgpu_rope_pipeline_key & other) const { + return type == other.type && inplace == other.inplace && has_ff == other.has_ff; + } +}; + +struct ggml_webgpu_rope_pipeline_key_hash { + size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.has_ff); + return seed; + } +}; + +/** SoftMax **/ + +struct ggml_webgpu_soft_max_pipeline_key { + ggml_type mask_type; + bool has_mask; + bool has_sink; + bool inplace; + + bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const { + return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink && + inplace == other.inplace; + } +}; + +struct ggml_webgpu_soft_max_pipeline_key_hash { + size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.mask_type); + ggml_webgpu_hash_combine(seed, key.has_mask); + ggml_webgpu_hash_combine(seed, key.has_sink); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + class ggml_webgpu_shader_lib { wgpu::Device device; pre_wgsl::Preprocessor preprocessor; @@ -582,6 +671,12 @@ class ggml_webgpu_shader_lib { std::unordered_map set_rows_pipelines; std::unordered_map set_pipelines; + std::unordered_map cpy_pipelines; + std::unordered_map glu_pipelines; + std::unordered_map + rope_pipelines; + std::unordered_map + soft_max_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -1679,6 +1774,236 @@ class ggml_webgpu_shader_lib { return flash_attn_pipelines[key]; } + webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) { + 
ggml_webgpu_cpy_pipeline_key key = { + .src_type = context.src0->type, + .dst_type = context.dst->type, + }; + + auto it = cpy_pipelines.find(key); + if (it != cpy_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "cpy"; + + switch (key.src_type) { + case GGML_TYPE_F32: + defines.push_back("SRC_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported src type for cpy shader"); + } + + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + variant += "_f16"; + break; + case GGML_TYPE_I32: + defines.push_back("DST_I32"); + variant += "_i32"; + break; + default: + GGML_ABORT("Unsupported dst type for cpy shader"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_cpy, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + cpy_pipelines[key] = pipeline; + return cpy_pipelines[key]; + } + + webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_glu_pipeline_key key = { + .glu_op = ggml_get_glu_op(context.dst), + .type = context.dst->type, + .split = (context.src1 != nullptr), + }; + + auto it = glu_pipelines.find(key); + if (it != glu_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "glu"; + + switch (key.glu_op) { + case GGML_GLU_OP_REGLU: + defines.push_back("OP_REGLU"); + variant += "_reglu"; + break; + case GGML_GLU_OP_GEGLU: + defines.push_back("OP_GEGLU"); + variant += "_geglu"; + break; + case GGML_GLU_OP_SWIGLU: + defines.push_back("OP_SWIGLU"); + variant += "_swiglu"; + break; + case 
GGML_GLU_OP_SWIGLU_OAI: + defines.push_back("OP_SWIGLU_OAI"); + variant += "_swiglu_oai"; + break; + case GGML_GLU_OP_GEGLU_ERF: + defines.push_back("OP_GEGLU_ERF"); + variant += "_geglu_erf"; + break; + case GGML_GLU_OP_GEGLU_QUICK: + defines.push_back("OP_GEGLU_QUICK"); + variant += "_geglu_quick"; + break; + default: + GGML_ABORT("Unsupported GLU op"); + } + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for GLU shader"); + } + + if (key.split) { + variant += "_split"; + } else { + defines.push_back("NO_SPLIT"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_glu, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + glu_pipelines[key] = pipeline; + return glu_pipelines[key]; + } + + webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_rope_pipeline_key key = { + .type = context.dst->type, + .inplace = context.inplace, + .has_ff = (context.src2 != nullptr), + }; + + auto it = rope_pipelines.find(key); + if (it != rope_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "rope"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for ROPE shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + if (key.has_ff) { + defines.push_back("FF_FUNC"); + variant += "_ff"; + } + + defines.push_back(std::string("WG_SIZE=") + 
std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_rope, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + rope_pipelines[key] = pipeline; + return rope_pipelines[key]; + } + + webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_soft_max_pipeline_key key = { + .mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32, + .has_mask = (context.src1 != nullptr), + .has_sink = (context.src2 != nullptr), + .inplace = context.inplace, + }; + + auto it = soft_max_pipelines.find(key); + if (it != soft_max_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "soft_max"; + + if (key.has_mask) { + defines.push_back("HAS_MASK"); + switch (key.mask_type) { + case GGML_TYPE_F32: + defines.push_back("MASK_F32"); + variant += "_mask_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("MASK_F16"); + variant += "_mask_f16"; + break; + default: + GGML_ABORT("Unsupported type for SOFT_MAX shader"); + } + } + + if (key.has_sink) { + defines.push_back("HAS_SINK"); + variant += "_sink"; + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_soft_max, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + soft_max_pipelines[key] = pipeline; + return soft_max_pipelines[key]; + } + private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, diff --git a/src/ggml-webgpu/ggml-webgpu.cpp b/src/ggml-webgpu/ggml-webgpu.cpp index 
5e16f84d..fa3c492a 100644 --- a/src/ggml-webgpu/ggml-webgpu.cpp +++ b/src/ggml-webgpu/ggml-webgpu.cpp @@ -364,13 +364,6 @@ struct webgpu_context_struct { wgpu::Buffer set_rows_dev_error_buf; wgpu::Buffer set_rows_host_error_buf; - std::map> cpy_pipelines; // src_type, dst_type - - std::map>> rope_pipelines; // type, ff, inplace - std::map>> glu_pipelines; // glu_op, type, split - - std::map>> soft_max_pipelines; // mask_type, has_sink, inplace - size_t memset_bytes_per_thread; }; @@ -849,6 +842,16 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0 } static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { @@ -875,9 +878,8 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type], - params, entries, wg_x); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { @@ -1914,6 +1916,19 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .src2 = src2, + .dst = dst, + .max_wg_size = 
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = ggml_webgpu_tensor_equal(src0, dst), + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + const int inplace = ggml_webgpu_tensor_equal(src0, dst); const int has_freq_factor = (src2 != nullptr); @@ -1996,12 +2011,22 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace]; - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + const int split = (src1 != nullptr); std::vector params = { @@ -2048,8 +2073,7 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split]; - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -2109,9 +2133,20 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, 
ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here - const int has_sink = (src2 != nullptr); + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .src2 = src2, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = ggml_webgpu_tensor_equal(src0, dst), + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx); + + const int inplace = ggml_webgpu_tensor_equal(src0, dst); + const int has_mask = (src1 != nullptr); + const int has_sink = (src2 != nullptr); float max_bias; memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); @@ -2120,15 +2155,15 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), - mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0, - mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0, - mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0, + has_mask ? 
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0, (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), @@ -2136,8 +2171,8 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], (uint32_t) src0->ne[2], - mask_type < 2 ? (uint32_t) src1->ne[2] : 0, - mask_type < 2 ? (uint32_t) src1->ne[3] : 0, + has_mask ? (uint32_t) src1->ne[2] : 0, + has_mask ? (uint32_t) src1->ne[3] : 0, *(uint32_t *) dst->op_params, // scale *(uint32_t *) &max_bias, *(uint32_t *) &n_head_log2, @@ -2152,7 +2187,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, src0) } }; uint32_t binding_num = 1; - if (mask_type < 2) { + if (has_mask) { entries.push_back({ .binding = binding_num, .buffer = ggml_webgpu_tensor_buf(src1), .offset = ggml_webgpu_tensor_align_offset(ctx, src1), @@ -2173,9 +2208,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, - ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries, - ggml_nrows(dst)); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(dst)); } static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { @@ -2885,139 +2918,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } -static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - 
webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); -} - -static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); - - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants); - 
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); -} - -static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - // REGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants); - - // GEGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants); - - // SWIGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] 
= - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants); - - // SWIGLU_OAI - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); - - // GEGLU_ERF - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); - - // GEGLU_QUICK - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, 
"geglu_quick_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); -} - -static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); - - // f32 (no mask) - webgpu_ctx->soft_max_pipelines[2][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants); - webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants); - webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants); - webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); - - // f32 mask (mask_type = 0) - webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants); - webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); - webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, 
wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); - webgpu_ctx->soft_max_pipelines[0][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, - "soft_max_f32_mask_f32_sink_inplace", constants); - - // f16 mask (mask_type = 1) - webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants); - webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); - webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); - webgpu_ctx->soft_max_pipelines[1][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, - "soft_max_f32_mask_f16_sink_inplace", constants); -} - static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { wgpu::RequestAdapterOptions options = {}; @@ -3183,10 +3083,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); - ggml_webgpu_init_cpy_pipeline(webgpu_ctx); - ggml_webgpu_init_rope_pipeline(webgpu_ctx); - ggml_webgpu_init_glu_pipeline(webgpu_ctx); - ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf, diff --git a/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/src/ggml-webgpu/wgsl-shaders/cpy.wgsl new file mode 100644 index 00000000..fa3bdf4e --- /dev/null +++ b/src/ggml-webgpu/wgsl-shaders/cpy.wgsl @@ -0,0 +1,81 @@ +enable f16; + +#ifdef SRC_F32 +#define SRC_TYPE f32 +#elif 
defined(SRC_F16) +#define SRC_TYPE f16 +#endif + +#ifdef DST_F32 +#define DST_TYPE f32 +#elif defined(DST_F16) +#define DST_TYPE f16 +#elif defined(DST_I32) +#define DST_TYPE i32 +#endif + +@group(0) @binding(0) +var src: array; + +@group(0) @binding(1) +var dst: array; + +struct Params{ + ne: u32, + offset_src: u32, + offset_dst: u32, + + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; + +@group(0) @binding(2) +var params: Params; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); + i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); + let i2 = i / (params.src_ne1 * params.src_ne0); + i = i % (params.src_ne1 * params.src_ne0); + let i1 = i / params.src_ne0; + let i0 = i % params.src_ne0; + + var j = gid.x; + let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + let j2 = j / (params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne1 * params.dst_ne0); + let j1 = j / params.dst_ne0; + let j0 = j % params.dst_ne0; + + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + + let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + + j2 * params.stride_dst2 + j3 * params.stride_dst3; + + dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx])); +} + diff --git a/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 8b5cfe71..79a3a959 100755 --- a/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -1,41 +1,8 @@ import os import re 
-import ast import argparse -def extract_block(text, name): - pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)' - match = re.search(pattern, text, re.DOTALL) - if not match: - raise ValueError(f"Missing block: {name}") - return match.group(1).strip() - - -def parse_decls(decls_text): - decls = {} - for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL): - decls[name.strip()] = code.strip() - return decls - - -def replace_repl_placeholders(variant, template_map): - for repl, code in variant["REPLS"].items(): - for key, val in template_map.items(): - # Match "key" and avoid matching subsequences using by using \b - code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code) - variant["REPLS"][repl] = code - return variant - - -def replace_placeholders(shader_text, replacements): - for key, val in replacements.items(): - # Match {{KEY}} literally, where KEY is escaped - pattern = r'{{\s*' + re.escape(key) + r'\s*}}' - shader_text = re.sub(pattern, str(val), shader_text) - return shader_text - - def expand_includes(shader, input_dir): """ Replace #include "file" lines in the text with the contents of that file. 
@@ -98,84 +65,24 @@ def write_shader(shader_name, shader_code, output_dir, outfile, input_dir): outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n') -def generate_variants(fname, input_dir, output_dir, outfile): - shader_path = os.path.join(input_dir, fname) - shader_base_name = fname.split(".")[0] - - with open(shader_path, "r", encoding="utf-8") as f: - text = f.read() - - try: - variants = ast.literal_eval(extract_block(text, "VARIANTS")) - except ValueError: - write_shader(shader_base_name, text, output_dir, outfile, input_dir) - else: - try: - decls_map = parse_decls(extract_block(text, "DECLS")) - except ValueError: - decls_map = {} - try: - templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES")) - except ValueError: - templates_map = {} - - for fname in sorted(os.listdir(input_dir)): - if fname.endswith(".tmpl"): - tmpl_path = os.path.join(input_dir, fname) - with open(tmpl_path, "r", encoding="utf-8") as f_tmpl: - decls = f_tmpl.read() - decls_map.update(parse_decls(decls)) - - shader_template = extract_block(text, "SHADER") - for variant in variants: - if "DECLS" in variant: - decls = variant["DECLS"] - else: - decls = [] - decls_code = "" - for key in decls: - if key not in decls_map: - raise ValueError(f"DECLS key '{key}' not found.") - decls_code += decls_map[key] + "\n\n" - final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template) - if "REPLS" in variant: - variant = replace_repl_placeholders(variant, templates_map) - final_shader = replace_placeholders(final_shader, variant["REPLS"]) - # second run to expand placeholders in repl_template - final_shader = replace_placeholders(final_shader, variant["REPLS"]) - final_shader = expand_includes(final_shader, input_dir) - - if "SHADER_NAME" in variant: - output_name = variant["SHADER_NAME"] - elif "SHADER_SUFFIX" in variant: - output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"] - elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] 
and "SRC1_TYPE" in variant["REPLS"]:
-            output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
-        elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]:
-            output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]])
-        elif "REPLS" in variant and "TYPE" in variant["REPLS"]:
-            output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
-        else:
-            output_name = shader_base_name
-        write_shader(output_name, final_shader, output_dir, outfile, input_dir)
-
-
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--input_dir", required=True)
     parser.add_argument("--output_file", required=True)
-    parser.add_argument("--output_dir")
     args = parser.parse_args()
 
-    if args.output_dir:
-        os.makedirs(args.output_dir, exist_ok=True)
-
     with open(args.output_file, "w", encoding="utf-8") as out:
         out.write("// Auto-generated shader embedding\n")
         out.write("#include <string>\n\n")
         for fname in sorted(os.listdir(args.input_dir)):
             if fname.endswith(".wgsl"):
-                generate_variants(fname, args.input_dir, args.output_dir, out)
+                shader_path = os.path.join(args.input_dir, fname)
+                shader_name = fname.replace(".wgsl", "")
+
+                with open(shader_path, "r", encoding="utf-8") as f:
+                    shader_code = f.read()
+
+                write_shader(shader_name, shader_code, None, out, args.input_dir)
 
 
 if __name__ == "__main__":
diff --git a/src/ggml-webgpu/wgsl-shaders/glu.wgsl b/src/ggml-webgpu/wgsl-shaders/glu.wgsl
new file mode 100644
index 00000000..e6d7608c
--- /dev/null
+++ b/src/ggml-webgpu/wgsl-shaders/glu.wgsl
@@ -0,0 +1,155 @@
+enable f16;
+
+#ifdef TYPE_F32
+#define DataType f32
+#endif
+#ifdef TYPE_F16
+#define DataType f16
+#endif
+
+#ifdef OP_REGLU
+fn op(a: DataType, b: DataType) -> DataType {
+    return max(a, 0) * b;
+}
+#endif
+
+#ifdef OP_GEGLU
+const SQRT_2_OVER_PI: DataType = 0.79788456080286535587989211986876;
+const GELU_COEF_A: DataType = 0.044715;
+
+fn 
op(a: DataType, b: DataType) -> DataType {
+    let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
+    return 0.5 * a * (2.0 - 2.0/ (exp(2* val) + 1)) * b;
+}
+#endif
+
+#ifdef OP_SWIGLU
+fn op(a: DataType, b: DataType) -> DataType {
+    return a / (1.0 + exp(-a)) * b;
+}
+#endif
+#ifdef OP_SWIGLU_OAI
+fn op(a: f32, b: f32) -> f32 {
+    let xi = min(a, params.limit);
+    let gi = max(min(b, params.limit), -params.limit);
+    var out_glu = xi / (1.0 + exp(-xi * params.alpha));
+    out_glu = out_glu * (1.0 + gi);
+    return out_glu;
+}
+#endif
+#ifdef OP_GEGLU_ERF
+const p_erf: DataType = 0.3275911;
+const a1_erf: DataType = 0.254829592;
+const a2_erf: DataType = -0.284496736;
+const a3_erf: DataType = 1.421413741;
+const a4_erf: DataType = -1.453152027;
+const a5_erf: DataType = 1.061405429;
+const SQRT_2_INV: DataType = 0.7071067811865476;
+
+fn op(a: DataType, b: DataType) -> DataType {
+    let a_div_sqr2 = a * SQRT_2_INV;
+    let sign_x = sign(a_div_sqr2);
+    let x = abs(a_div_sqr2);
+    let t = 1.0 / (1.0 + p_erf * x);
+    let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
+    let erf_approx = sign_x * y;
+    return 0.5 * a * (1.0 + erf_approx) * b;
+}
+#endif
+#ifdef OP_GEGLU_QUICK
+const GELU_QUICK_COEF: DataType = -1.702;
+
+fn op(a: DataType, b: DataType) -> DataType {
+    return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
+}
+#endif
+
+struct Params {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src01: u32,
+    stride_src02: u32,
+    stride_src03: u32,
+
+    stride_src11: u32,
+    stride_src12: u32,
+    stride_src13: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // shape of dst
+    ne: u32,
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    swapped: u32,
+    alpha: f32,
+    limit: f32,
+}
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<DataType>;
+
+#ifdef NO_SPLIT
+@group(0) @binding(1)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+fn a_value(base: u32) -> DataType {
+    let offset: u32 = select(0, params.ne0, params.swapped != 0);
+    return src0[base + offset];
+}
+
+fn b_value(base: u32) -> DataType {
+    let offset: u32 = select(params.ne0, 0, params.swapped != 0);
+    return src0[base + offset];
+}
+
+#else
+@group(0) @binding(1)
+var<storage, read_write> src1: array<DataType>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+fn a_value(base: u32) -> DataType {
+    return src0[base];
+}
+
+fn b_value(base: u32) -> DataType {
+    return src1[base];
+}
+
+#endif
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+        return;
+    }
+
+    var i = gid.x;
+    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+    i = i % (params.ne2 * params.ne1 * params.ne0);
+    let i2 = i / (params.ne1 * params.ne0);
+    i = i % (params.ne1 * params.ne0);
+    let i1 = i / params.ne0;
+    let i0 = i % params.ne0;
+
+    let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
+    let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
+    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
+
+    dst[i_dst] = op(a_value(i_a), b_value(i_b));
+}
diff --git a/src/ggml-webgpu/wgsl-shaders/rope.wgsl b/src/ggml-webgpu/wgsl-shaders/rope.wgsl
new file mode 100644
index 00000000..1c874e14
--- /dev/null
+++ b/src/ggml-webgpu/wgsl-shaders/rope.wgsl
@@ -0,0 +1,224 @@
+enable f16;
+
+#ifdef TYPE_F32
+#define DataType f32
+#endif
+#ifdef TYPE_F16
+#define DataType f16
+#endif
+
+struct Params {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_src2: u32,
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src01: u32,
+    stride_src02: u32,
+    stride_src03: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    n_threads: u32,
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    n_dims: u32,
+    mode: u32,
+    theta_scale: f32,
+    attn_factor: 
f32,
+    freq_scale: f32,
+    ext_factor: f32,
+    corr_dim0: f32,
+    corr_dim1: f32,
+    sections0: u32,
+    sections1: u32,
+    sections2: u32,
+    sections3: u32
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<DataType>;
+@group(0) @binding(1)
+var<storage, read_write> src1: array<i32>;
+
+#ifdef INPLACE
+
+#ifdef FF_FUNC
+
+@group(0) @binding(2)
+var<storage, read_write> src2: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+#else
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#endif
+
+#else
+
+#ifdef FF_FUNC
+@group(0) @binding(2)
+var<storage, read_write> src2: array<f32>;
+
+@group(0) @binding(3)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(4)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(2)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#endif
+#endif
+
+#ifdef FF_FUNC
+fn freq_factor(i: u32) -> f32 {
+    return src2[params.offset_src2 + i/2];
+}
+
+#else
+fn freq_factor(i: u32) -> f32 {
+    return 1.0f;
+}
+#endif
+#ifdef INPLACE
+fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
+    src0[i_dst0] = DataType(out0);
+    src0[i_dst1] = DataType(out1);
+}
+#else
+fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
+    dst[i_dst0] = DataType(out0);
+    dst[i_dst1] = DataType(out1);
+}
+#endif
+
+fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
+    let y = (f32(i / 2) - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// returns vector of (cos_theta, sin_theta)
+// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row
+fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
+    var mscale = params.attn_factor;
+    var theta = params.freq_scale * theta_extrap;
+    if (params.ext_factor != 0.0f) {
+        let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
+        theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
+        mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale);
+    }
+    return vec2<f32>(cos(theta) * mscale, sin(theta) * mscale);
+}
+
+fn pair_base(i0: u32, div_2: bool) -> u32 {
+    if (div_2) {
+        return i0 / 2;
+    } else {
+        
return i0;
+    }
+}
+
+fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
+    if (is_vision) {
+        return params.n_dims;
+    } else if (is_neox || is_mrope) {
+        return params.n_dims / 2;
+    } else {
+        return 1;
+    }
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    // two elements per n_threads
+    if (gid.x >= params.n_threads) {
+        return;
+    }
+
+    let is_neox = bool(params.mode & 2);
+    let is_mrope = bool(params.mode & 8);
+    let is_imrope = params.mode == 40;
+    let is_vision = params.mode == 24;
+
+    var i = gid.x * 2; // start index for this thread
+    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+    i = i % (params.ne2 * params.ne1 * params.ne0);
+    let i2 = i / (params.ne1 * params.ne0);
+    i = i % (params.ne1 * params.ne0);
+    let i1 = i / params.ne0;
+    let i0 = i % params.ne0;
+
+    let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
+    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+
+    if (i0 >= params.n_dims && !is_vision) {
+        let i_src = i_src_row + i0;
+        let i_dst = i_dst_row + i0;
+        rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
+        return;
+    }
+
+    var theta_base_mult: u32 = 0;
+    var theta_scale_pwr: u32 = i0 / 2;
+    if (is_mrope) {
+        let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3;
+        let sec_w = params.sections1 + params.sections0;
+        let sec_e = params.sections2 + sec_w;
+        let sector = (i0 / 2) % sect_dims;
+        if (is_imrope) {
+            if (sector % 3 == 1 && sector < 3 * params.sections1) {
+                theta_base_mult = 1;
+            } else if (sector % 3 == 2 && sector < 3 * params.sections2) {
+                theta_base_mult = 2;
+            } else if (sector % 3 == 0 && sector < 3 * params.sections0) {
+                theta_base_mult = 0;
+            } else {
+                theta_base_mult = 3;
+            }
+        } else {
+            if (sector >= params.sections0 && sector < sec_w) {
+                theta_base_mult = 1;
+                if (is_vision) {
+                    theta_scale_pwr = sector - params.sections0;
+                }
+            } else if (sector >= sec_w && sector < sec_e) {
+                theta_base_mult = 2;
+                if (is_vision) {
+                    theta_scale_pwr = sector - sec_w;
+                }
+            } else if (sector >= sec_e) {
+                if (is_vision) {
+                    theta_scale_pwr = sector - sec_e;
+                    theta_scale_pwr = (i0 / 2) % sec_e;
+                }
+                theta_base_mult = 3;
+            } else if (is_vision) {
+                theta_scale_pwr = sector;
+            }
+        }
+    }
+    let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
+    let thetas = rope_yarn(theta_base/freq_factor(i0), i0);
+
+    let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision);
+    let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision);
+
+    let x0 = f32(src0[i_src]);
+    let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
+    rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
+
+}
diff --git a/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl b/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl
new file mode 100644
index 00000000..10edf136
--- /dev/null
+++ b/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl
@@ -0,0 +1,245 @@
+enable f16;
+
+#ifdef MASK_F32
+#define MaskType f32
+#endif
+#ifdef MASK_F16
+#define MaskType f16
+#endif
+
+struct Params {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_sinks: u32,
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src01: u32,
+    stride_src02: u32,
+    stride_src03: u32,
+
+    stride_src11: u32,
+    stride_src12: u32,
+    stride_src13: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // shape of src0/dst
+    ne: u32,
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    // shape of src1
+    ne12: u32,
+    ne13: u32,
+
+    scale: f32,
+    max_bias: f32,
+    n_head_log2: f32,
+    m0: f32,
+    m1: f32,
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+#ifdef HAS_MASK
+#ifdef HAS_SINK
+@group(0) @binding(1)
+var<storage, read_write> mask: array<MaskType>;
+@group(0) @binding(2)
+var<storage, read_write> sinks: array<f32>;
+
+#ifdef INPLACE
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(3)
+var<storage, read_write> dst: array<f32>;
+@group(0) @binding(4)
+var<uniform> params: Params;
+#endif
+
+#else
+@group(0) @binding(1)
+var<storage, read_write> mask: array<MaskType>;
+
+#ifdef INPLACE
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+@group(0) @binding(3)
+var<uniform> params: Params;
+#endif
+#endif
+
+#else
+#ifdef HAS_SINK
+@group(0) @binding(1)
+var<storage, read_write> sinks: array<f32>;
+
+#ifdef INPLACE
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+@group(0) @binding(3)
+var<uniform> params: Params;
+#endif
+
+#else
+#ifdef INPLACE
+@group(0) @binding(1)
+var<uniform> params: Params;
+#else
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+@group(0) @binding(2)
+var<uniform> params: Params;
+#endif
+#endif
+#endif
+
+#ifdef INPLACE
+fn inter_value(i: u32) -> f32 {
+    return src[i];
+}
+fn update(i: u32, val: f32) {
+    src[i] = val;
+}
+
+#else
+fn inter_value(i: u32) -> f32 {
+    return dst[i];
+}
+fn update(i: u32, val: f32) {
+    dst[i] = val;
+}
+#endif
+
+#ifdef HAS_MASK
+fn mask_val(i: u32) -> f32 {
+    return f32(mask[i]);
+}
+
+#else
+fn mask_val(i: u32) -> f32 {
+    return 0.0;
+}
+#endif
+
+#ifdef HAS_SINK
+fn lower_max_bound(i2: u32) -> f32 {
+    return sinks[params.offset_sinks + i2];
+}
+fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
+    return val + exp(sinks[params.offset_sinks + i2] - max_val);
+}
+#else
+fn lower_max_bound(i2: u32) -> f32 {
+    return -1e30;
+}
+fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
+    return val;
+}
+#endif
+
+const CACHE_SIZE: u32 = 16;
+var<workgroup> scratch: array<f32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+
+    var i = wid.x;
+    let i3 = i / (params.ne2 * params.ne1);
+    i = i % (params.ne2 * params.ne1);
+    let i2 = i / params.ne1;
+    let i1 = i % params.ne1;
+    let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
+    
let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
+    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+    let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
+
+    let head = f32(i2);
+    let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);
+
+    var cache: array<f32, CACHE_SIZE>;
+
+    var max_val = lower_max_bound(i2);
+    var col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
+        max_val = max(max_val, val);
+        if (col < CACHE_SIZE) {
+            cache[col] = val;
+        }
+        col += WG_SIZE;
+    }
+
+    scratch[lid.x] = max_val;
+    workgroupBarrier();
+    var offset: u32 = WG_SIZE / 2;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
+        }
+        offset = offset / 2;
+        workgroupBarrier();
+    }
+    let row_max = scratch[0];
+    workgroupBarrier();
+
+    var sum = 0.0f;
+    col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col),
+                         cache[col], col < CACHE_SIZE);
+        let ex = exp(val - row_max);
+        sum += ex;
+        if (col < CACHE_SIZE) {
+            cache[col] = ex;
+        } else {
+            update(i_dst_row + col, ex);
+        }
+        col += WG_SIZE;
+    }
+
+    scratch[lid.x] = sum;
+    workgroupBarrier();
+    offset = WG_SIZE / 2;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            scratch[lid.x] += scratch[lid.x + offset];
+        }
+        offset = offset / 2;
+        workgroupBarrier();
+    }
+    let row_sum = add_sinks(scratch[0], i2, row_max);
+
+    let sum_recip = 1.0 / row_sum;
+    col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        update(i_dst_row + col, 
select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip); + col += WG_SIZE; + } +} +