#define GGML_WEBGPU_F16_SIZE_BYTES 2
#define GGML_WEBGPU_F32_SIZE_BYTES 4
+#define GGML_WEBGPU_I32_SIZE_BYTES 4
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
#define GGML_WEBGPU_KV_SEQ_PAD 256u
-struct ggml_webgpu_flash_attn_shader_lib_context {
+#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u
+
+struct ggml_webgpu_processed_shader {
+    std::string wgsl;     // fully preprocessed WGSL source for this variant
+    std::string variant;  // unique name identifying the compiled shader variant
+    // Family-specific decisions struct (e.g. ggml_webgpu_flash_attn_shader_decisions),
+    // heap-allocated by the preprocessor and owned by the cached pipeline via
+    // pipeline.context. Defaulted to nullptr so preprocessors that record no
+    // decisions cannot hand callers an indeterminate pointer.
+    void * decisions = nullptr;
+};
+
+// Same hash combine function as in boost
+template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
+    seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+}
+
+/** FlashAttention */
+
+struct ggml_webgpu_flash_attn_pipeline_key {
    ggml_type kv_type;
    uint32_t head_dim_qk;
    uint32_t head_dim_v;
+    // kv_direct participates in operator== and the hash below, and is set via
+    // a designated initializer at the construction site, so it must be a
+    // member. Declared right after head_dim_v to keep designated-initializer
+    // order valid (C++20 requires initializers in declaration order).
+    bool kv_direct;
    bool has_mask;
    bool has_sinks;
    bool uses_logit_softcap;
-    uint32_t sg_mat_m;
-    uint32_t sg_mat_n;
-    uint32_t sg_mat_k;
-    size_t wg_mem_limit_bytes;
-    uint32_t max_subgroup_size;
+
+    bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
+        return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
+               kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
+               uses_logit_softcap == other.uses_logit_softcap;
+    }
+};
+
+// Hash functor for the flash-attention pipeline cache. Combines exactly the
+// fields compared in ggml_webgpu_flash_attn_pipeline_key::operator== so that
+// equal keys always hash equally.
+struct ggml_webgpu_flash_attn_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.kv_type);
+        ggml_webgpu_hash_combine(seed, key.head_dim_qk);
+        ggml_webgpu_hash_combine(seed, key.head_dim_v);
+        ggml_webgpu_hash_combine(seed, key.kv_direct);
+        ggml_webgpu_hash_combine(seed, key.has_mask);
+        ggml_webgpu_hash_combine(seed, key.has_sinks);
+        ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
+        return seed;
+    }
+};
+
+// Everything the flash-attention WGSL preprocessor needs: the pipeline cache
+// key plus device capabilities (subgroup-matrix dimensions, workgroup memory
+// budget, maximum subgroup size) that shape tile-size decisions.
+struct ggml_webgpu_flash_attn_shader_lib_context {
+    ggml_webgpu_flash_attn_pipeline_key key;
+    uint32_t sg_mat_m;
+    uint32_t sg_mat_n;
+    uint32_t sg_mat_k;
+    size_t wg_mem_limit_bytes;
+    uint32_t max_subgroup_size;
};
struct ggml_webgpu_flash_attn_shader_decisions {
uint32_t wg_size = 0;
};
-struct ggml_webgpu_processed_shader {
- std::string wgsl;
- std::string variant;
- ggml_webgpu_flash_attn_shader_decisions decisions;
-};
-
// This is exposed because it's necessary in supports_op
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
uint32_t kv_tile,
}
static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
- const size_t limit_bytes = context.wg_mem_limit_bytes;
- const size_t q_tile = context.sg_mat_m;
- const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
- 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
+ const size_t limit_bytes = context.wg_mem_limit_bytes;
+ const size_t q_tile = context.sg_mat_m;
+ const size_t base_q_bytes =
+ (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
+ 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
size_t bytes_per_kv = 0;
- if (!context.kv_direct) {
- bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v);
+ if (!context.key.kv_direct) {
+ bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
}
- if (context.has_mask) {
+ if (context.key.has_mask) {
bytes_per_kv += q_tile;
}
bytes_per_kv += q_tile;
std::vector<std::string> defines;
std::string variant = "flash_attn";
- switch (context.kv_type) {
+ switch (context.key.kv_type) {
case GGML_TYPE_F32:
defines.push_back("KV_F32");
break;
default:
GGML_ABORT("Unsupported KV type for flash attention shader");
}
- variant += std::string("_") + ggml_type_name(context.kv_type);
+ variant += std::string("_") + ggml_type_name(context.key.kv_type);
- if (context.has_mask) {
+ if (context.key.has_mask) {
defines.push_back("MASK");
variant += "_mask";
}
- if (context.has_sinks) {
+ if (context.key.has_sinks) {
defines.push_back("SINKS");
variant += "_sinks";
}
- if (context.uses_logit_softcap) {
+ if (context.key.uses_logit_softcap) {
defines.push_back("LOGIT_SOFTCAP");
variant += "_lgsc";
}
- if (context.kv_direct) {
+ if (context.key.kv_direct) {
defines.push_back("KV_DIRECT");
variant += "_kvdirect";
}
- defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk));
- variant += std::string("_hsqk") + std::to_string(context.head_dim_qk);
-
- defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v));
- variant += std::string("_hsv") + std::to_string(context.head_dim_v);
+ defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
+ variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
+ defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
+ variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
// For now these are not part of the variant name
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
uint32_t q_tile = context.sg_mat_m;
uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
- if (context.kv_direct) {
+ if (context.key.kv_direct) {
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
// Avoids having to use bounds-checks and decreasing performance for direct KV loads
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
ggml_webgpu_processed_shader result;
- result.wgsl = preprocessor.preprocess(shader_src, defines);
- result.variant = variant;
- result.decisions.q_tile = q_tile;
- result.decisions.kv_tile = kv_tile;
- result.decisions.wg_size = wg_size;
+ result.wgsl = preprocessor.preprocess(shader_src, defines);
+ result.variant = variant;
+ ggml_webgpu_flash_attn_shader_decisions * decisions = new ggml_webgpu_flash_attn_shader_decisions();
+ decisions->q_tile = q_tile;
+ decisions->kv_tile = kv_tile;
+ decisions->wg_size = wg_size;
+ result.decisions = decisions;
+ return result;
+}
+
+/** Generic **/
+
+// Inputs for shaders that only vary on vectorization and workgroup size.
+struct ggml_webgpu_generic_shader_lib_context {
+    int vec4;             // non-zero selects the VEC4 shader path
+    uint32_t max_wg_size; // device limit on invocations per workgroup
+};
+
+struct ggml_webgpu_generic_shader_decisions {
+    uint32_t wg_size; // workgroup size baked into the shader at preprocess time
+};
+
+// Preprocess a "generic" shader: optionally emits VEC4 and always bakes in
+// WG_SIZE. Allocates a decisions struct (owned by the cached pipeline via
+// pipeline.context) recording the chosen workgroup size, matching the pad /
+// set_rows / unary preprocessors so callers can always read the context.
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader(
+    pre_wgsl::Preprocessor & preprocessor,
+    const char * shader_src,
+    const ggml_webgpu_generic_shader_lib_context & context,
+    const std::string & base_variant) {
+    std::vector<std::string> defines;
+    std::string variant = base_variant;
+
+    if (context.vec4) {
+        defines.push_back("VEC4");
+        variant += "_vec";
+    }
+
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+    ggml_webgpu_processed_shader result;
+    result.wgsl = preprocessor.preprocess(shader_src, defines);
+    result.variant = variant;
+    ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
+    decisions->wg_size = context.max_wg_size;
+    result.decisions = decisions;
+    return result;
+}
+
+/** Pad **/
+
+// Cache key for PAD pipelines: only circular vs. zero padding varies.
+struct ggml_webgpu_pad_pipeline_key {
+    bool circular;
+
+    bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
+};
+
+struct ggml_webgpu_pad_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.circular);
+        return seed;
+    }
+};
+
+// Key plus the device invocation limit, which fixes the baked-in WG_SIZE.
+struct ggml_webgpu_pad_shader_lib_context {
+    ggml_webgpu_pad_pipeline_key key;
+    uint32_t max_wg_size;
+};
+
+// Preprocess the PAD shader: emits CIRCULAR for wrap-around padding and bakes
+// in WG_SIZE; the chosen workgroup size is recorded in a decisions struct
+// owned by the cached pipeline (never freed — pipelines live for the context).
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader(
+    pre_wgsl::Preprocessor & preprocessor,
+    const char * shader_src,
+    const ggml_webgpu_pad_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string variant = "pad";
+
+    if (context.key.circular) {
+        defines.push_back("CIRCULAR");
+        variant += "_circular";
+    }
+
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+    ggml_webgpu_processed_shader result;
+    result.wgsl = preprocessor.preprocess(shader_src, defines);
+    result.variant = variant;
+    ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
+    decisions->wg_size = context.max_wg_size;
+    result.decisions = decisions;
+    return result;
+}
+
+/** Argsort **/
+
+// Tuning inputs shared by the argsort and argsort-merge preprocessors.
+struct ggml_webgpu_argsort_shader_lib_context {
+    uint32_t max_wg_size;      // device limit on invocations per workgroup
+    size_t wg_mem_limit_bytes; // device limit on workgroup (shared) memory
+    int32_t order;             // sort order, baked into the shader as ORDER
+};
+
+struct ggml_webgpu_argsort_shader_decisions {
+    uint32_t wg_size = 0; // workgroup size chosen at preprocess time
+};
+
+// Preprocess the argsort shader, choosing the largest power-of-two workgroup
+// size that fits the device limits.
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader(
+    pre_wgsl::Preprocessor & preprocessor,
+    const char * shader_src,
+    const ggml_webgpu_argsort_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string variant = "argsort";
+    defines.push_back(std::string("ORDER=") + std::to_string(context.order));
+    variant += std::string("_order") + std::to_string(context.order);
+    // Double while the doubled size still respects the invocation limit and
+    // the current size's i32 footprint stays within half the workgroup-memory
+    // budget — so after the final doubling, wg_size i32 entries fit within the
+    // full budget. (Presumably the i32 storage is the index array in the WGSL
+    // — confirm against the shader source.)
+    uint32_t wg_size = 1;
+    while (wg_size * 2 <= context.max_wg_size &&
+           wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
+        wg_size *= 2;
+    }
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+    ggml_webgpu_processed_shader result;
+    result.wgsl = preprocessor.preprocess(shader_src, defines);
+    result.variant = variant;
+    ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions();
+    decisions->wg_size = wg_size;
+    result.decisions = decisions;
+    return result;
+}
+
+// Preprocess the argsort merge pass. Uses no workgroup memory, so the size is
+// simply capped by the merge maximum and the device invocation limit.
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader(
+    pre_wgsl::Preprocessor & preprocessor,
+    const char * shader_src,
+    const ggml_webgpu_argsort_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string variant = "argsort_merge";
+    defines.push_back(std::string("ORDER=") + std::to_string(context.order));
+    variant += std::string("_order") + std::to_string(context.order);
+    uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+    ggml_webgpu_processed_shader result;
+    result.wgsl = preprocessor.preprocess(shader_src, defines);
+    result.variant = variant;
+    ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions();
+    decisions->wg_size = wg_size;
+    result.decisions = decisions;
+    return result;
+}
+
+/** Set Rows **/
+
+// Cache key for SET_ROWS pipelines: destination type, vec4 fast path, and
+// whether row indices are i64 (the i64 path also wires up an error buffer at
+// the call site).
+struct ggml_webgpu_set_rows_pipeline_key {
+    int dst_type;
+    int vec4;
+    int i64_idx;
+
+    bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
+        return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
+    }
+};
+
+struct ggml_webgpu_set_rows_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.dst_type);
+        ggml_webgpu_hash_combine(seed, key.vec4);
+        ggml_webgpu_hash_combine(seed, key.i64_idx);
+        return seed;
+    }
+};
+
+struct ggml_webgpu_set_rows_shader_lib_context {
+    ggml_webgpu_set_rows_pipeline_key key;
+    uint32_t max_wg_size;
+};
+
+// Preprocess the SET_ROWS shader: selects the destination-type define
+// (f32/f16 only), the optional VEC4 and I64_IDX paths, and bakes in WG_SIZE.
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader(
+    pre_wgsl::Preprocessor & preprocessor,
+    const char * shader_src,
+    const ggml_webgpu_set_rows_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string variant = "set_rows";
+
+    switch (context.key.dst_type) {
+        case GGML_TYPE_F32:
+            defines.push_back("DST_F32");
+            variant += "_dstf32";
+            break;
+        case GGML_TYPE_F16:
+            defines.push_back("DST_F16");
+            variant += "_dstf16";
+            break;
+        default:
+            GGML_ABORT("Unsupported dst type for set_rows shader");
+    }
+
+    if (context.key.vec4) {
+        defines.push_back("VEC4");
+        variant += "_vec";
+    }
+    if (context.key.i64_idx) {
+        defines.push_back("I64_IDX");
+        variant += "_i64idx";
+    }
+
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+    ggml_webgpu_processed_shader result;
+    result.wgsl = preprocessor.preprocess(shader_src, defines);
+    result.variant = variant;
+    ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
+    decisions->wg_size = context.max_wg_size;
+    result.decisions = decisions;
+    return result;
+}
+
+// Cache key for element-wise pipelines. `op` holds a ggml_unary_op when
+// is_unary is true, otherwise a ggml_op (several element-wise ops share this
+// shader path).
+struct ggml_webgpu_unary_pipeline_key {
+    int type;
+    int op;
+    bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
+    bool inplace;
+
+    bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
+        return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
+    }
+};
+
+struct ggml_webgpu_unary_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        ggml_webgpu_hash_combine(seed, key.op);
+        ggml_webgpu_hash_combine(seed, key.is_unary);
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        return seed;
+    }
+};
+
+struct ggml_webgpu_unary_shader_lib_context {
+    ggml_webgpu_unary_pipeline_key key;
+    uint32_t max_wg_size;
+};
+
+// Preprocess the unary/element-wise shader. The op's own name (unary-op name
+// or ggml op name) is pushed as a define to select the operation inside the
+// WGSL, then the type (f32/f16) and optional INPLACE variants are applied.
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
+    pre_wgsl::Preprocessor & preprocessor,
+    const char * shader_src,
+    const ggml_webgpu_unary_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) :
+                                                 ggml_op_name((ggml_op) context.key.op);
+    // Operation-specific behavior
+    defines.push_back(variant);
+
+    switch (context.key.type) {
+        case GGML_TYPE_F32:
+            defines.push_back("TYPE_F32");
+            variant += "_f32";
+            break;
+        case GGML_TYPE_F16:
+            defines.push_back("TYPE_F16");
+            variant += "_f16";
+            break;
+        default:
+            GGML_ABORT("Unsupported type for unary shader");
+    }
+
+    if (context.key.inplace) {
+        defines.push_back("INPLACE");
+        variant += "_inplace";
+    }
+
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+    ggml_webgpu_processed_shader result;
+    result.wgsl = preprocessor.preprocess(shader_src, defines);
+    result.variant = variant;
+    ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
+    decisions->wg_size = context.max_wg_size;
+    result.decisions = decisions;
    return result;
}
struct webgpu_command {
    wgpu::CommandBuffer commands;
-    webgpu_pool_bufs params_bufs;
+    // One params buffer pair per dispatch recorded in this command buffer;
+    // the multi-dispatch builder stages several parameter uploads at once.
+    std::vector<webgpu_pool_bufs> params_bufs;
    std::optional<webgpu_pool_bufs> set_rows_error_bufs;
#ifdef GGML_WEBGPU_GPU_PROFILE
    webgpu_gpu_profile_bufs timestamp_query_bufs;
#endif
};
-struct flash_attn_pipeline_key {
- int q_type;
- int kv_type;
- int dst_type;
- uint32_t head_dim_qk;
- uint32_t head_dim_v;
- bool kv_direct;
- bool has_mask;
- bool has_sinks;
- bool uses_logit_softcap;
-
- bool operator==(const flash_attn_pipeline_key & other) const {
- return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type &&
- head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct &&
- has_mask == other.has_mask && has_sinks == other.has_sinks &&
- uses_logit_softcap == other.uses_logit_softcap;
- }
-};
-
-// Same hash combine function as in boost
-template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
- seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
-}
-
-struct flash_attn_pipeline_key_hash {
- size_t operator()(const flash_attn_pipeline_key & key) const {
- size_t seed = 0;
- ggml_webgpu_hash_combine(seed, key.q_type);
- ggml_webgpu_hash_combine(seed, key.kv_type);
- ggml_webgpu_hash_combine(seed, key.dst_type);
- ggml_webgpu_hash_combine(seed, key.head_dim_qk);
- ggml_webgpu_hash_combine(seed, key.head_dim_v);
- ggml_webgpu_hash_combine(seed, key.kv_direct);
- ggml_webgpu_hash_combine(seed, key.has_mask);
- ggml_webgpu_hash_combine(seed, key.has_sinks);
- ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
- return seed;
- }
-};
-
// All the base objects needed to run operations on a WebGPU device
struct webgpu_context_struct {
wgpu::Instance instance;
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
- std::unordered_map<flash_attn_pipeline_key, webgpu_pipeline, flash_attn_pipeline_key_hash> flash_attn_pipelines;
+ std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
+ flash_attn_pipelines;
- std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
+ std::unordered_map<int, webgpu_pipeline> argmax_pipelines; // key is vec4
+ std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order (asc/desc)
+ std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order (asc/desc)
+ std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
+ std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
+
+ std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
+ set_rows_pipelines;
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
std::map<int, webgpu_pipeline> scale_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
- std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> unary_pipelines; // unary_op, type, inplace
+ std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
+ unary_pipelines;
+ std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash> pad_pipelines;
size_t memset_bytes_per_thread;
for (const auto & command : commands) {
command_buffers.push_back(command.commands);
- params_bufs.push_back(command.params_bufs);
+ params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
if (command.set_rows_error_bufs) {
set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
}
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
}
// Free the staged buffers
- ctx->param_buf_pool.free_bufs({ params_bufs });
+ ctx->param_buf_pool.free_bufs(params_bufs);
});
futures.push_back({ p_f });
return { futures };
}
-static webgpu_command ggml_backend_webgpu_build(webgpu_context & ctx,
- webgpu_pipeline & pipeline,
- std::vector<uint32_t> params,
- std::vector<wgpu::BindGroupEntry> bind_group_entries,
- uint32_t wg_x,
- uint32_t wg_y = 1,
- std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
- webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
-
- ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
- uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
- for (size_t i = 0; i < params.size(); i++) {
- _params[i] = params[i];
- };
+static webgpu_command ggml_backend_webgpu_build_multi(
+ webgpu_context & ctx,
+ const std::vector<webgpu_pipeline> & pipelines,
+ const std::vector<std::vector<uint32_t>> & params_list,
+ const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
+ const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list,
+ const std::optional<webgpu_pool_bufs> & set_rows_error_bufs = std::nullopt) {
+ GGML_ASSERT(pipelines.size() == params_list.size());
+ GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
+ GGML_ASSERT(pipelines.size() == workgroups_list.size());
+
+ std::vector<webgpu_pool_bufs> params_bufs_list;
+ std::vector<wgpu::BindGroup> bind_groups;
+
+ for (size_t i = 0; i < pipelines.size(); i++) {
+ webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
+
+ ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0,
+ params_bufs.host_buf.GetSize());
+ uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
+ for (size_t j = 0; j < params_list[i].size(); j++) {
+ _params[j] = params_list[i][j];
+ }
+ params_bufs.host_buf.Unmap();
- params_bufs.host_buf.Unmap();
+ std::vector<wgpu::BindGroupEntry> entries = bind_group_entries_list[i];
+ uint32_t params_binding_num = entries.size();
+ entries.push_back({ .binding = params_binding_num,
+ .buffer = params_bufs.dev_buf,
+ .offset = 0,
+ .size = params_bufs.dev_buf.GetSize() });
- uint32_t params_bufs_binding_num = bind_group_entries.size();
- bind_group_entries.push_back({ .binding = params_bufs_binding_num,
- .buffer = params_bufs.dev_buf,
- .offset = 0,
- .size = params_bufs.dev_buf.GetSize() });
+ wgpu::BindGroupDescriptor bind_group_desc;
+ bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0);
+ bind_group_desc.entryCount = entries.size();
+ bind_group_desc.entries = entries.data();
+ bind_group_desc.label = pipelines[i].name.c_str();
+ bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc));
- wgpu::BindGroupDescriptor bind_group_desc;
- bind_group_desc.layout = pipeline.pipeline.GetBindGroupLayout(0);
- bind_group_desc.entryCount = bind_group_entries.size();
- bind_group_desc.entries = bind_group_entries.data();
- bind_group_desc.label = pipeline.name.c_str();
- wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
+ params_bufs_list.push_back(params_bufs);
+ }
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
- encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
+ for (const auto & params_bufs : params_bufs_list) {
+ encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
+ }
+
+ // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
+ if (set_rows_error_bufs) {
+ encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
+ set_rows_error_bufs->host_buf.GetSize());
+ }
#ifdef GGML_WEBGPU_GPU_PROFILE
- // --- Profiling: GPU timestamp queries ---
- // Allocate a timestamp query buffer (2 timestamps: start/end)
webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
ts_bufs.host_buf.Unmap();
#else
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
#endif
- pass.SetPipeline(pipeline.pipeline);
- pass.SetBindGroup(0, bind_group);
- pass.DispatchWorkgroups(wg_x, wg_y, 1);
+ for (size_t i = 0; i < pipelines.size(); i++) {
+ pass.SetPipeline(pipelines[i].pipeline);
+ pass.SetBindGroup(0, bind_groups[i]);
+ pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1);
+ }
pass.End();
#ifdef GGML_WEBGPU_GPU_PROFILE
- // Resolve the query set into the device buffer
encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0);
encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize());
#endif
- // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
- if (set_rows_error_bufs) {
- encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
- set_rows_error_bufs->host_buf.GetSize());
- }
-
wgpu::CommandBuffer commands = encoder.Finish();
webgpu_command result = {};
result.commands = commands;
- result.params_bufs = params_bufs;
+ result.params_bufs = params_bufs_list;
result.set_rows_error_bufs = set_rows_error_bufs;
#ifdef GGML_WEBGPU_GPU_PROFILE
result.timestamp_query_bufs = ts_bufs;
- result.pipeline_name = pipeline.name;
+ // TODO: handle multiple pipeline names
+ result.pipeline_name = pipelines.front().name;
#endif
return result;
}
+// Convenience wrapper: build a command with a single pipeline dispatch by
+// delegating to the multi-dispatch builder with one-element lists.
+static webgpu_command ggml_backend_webgpu_build(webgpu_context & ctx,
+                                                webgpu_pipeline & pipeline,
+                                                std::vector<uint32_t> params,
+                                                std::vector<wgpu::BindGroupEntry> bind_group_entries,
+                                                uint32_t wg_x,
+                                                uint32_t wg_y = 1,
+                                                std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
+    return ggml_backend_webgpu_build_multi(ctx,
+                                           {
+                                               pipeline
+                                           },
+                                           { params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs);
+}
+
static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
wgpu::Buffer & buf,
uint32_t value,
return ggml_backend_webgpu_build(ctx, ctx->cpy_pipelines[src->type][dst->type], params, entries, wg_x);
}
+// Build the command for GGML_OP_PAD: looks up (or compiles and caches) the
+// pipeline variant keyed on circular vs. zero padding, packs misalignments,
+// strides, shapes and per-dimension pad sizes into the params buffer, and
+// dispatches one invocation per destination element.
+static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+    // op param 8 selects the CIRCULAR (wrap-around) shader variant.
+    const bool circular = ggml_get_op_params_i32(dst, 8) != 0;
+
+    ggml_webgpu_pad_pipeline_key pipeline_key = { .circular = circular };
+    ggml_webgpu_pad_shader_lib_context shader_lib_ctx = { .key = pipeline_key,
+                                                          .max_wg_size =
+                                                              ctx->limits.maxComputeInvocationsPerWorkgroup };
+
+    webgpu_pipeline pipeline;
+    {
+        // TODO: remove guard once pipeline caches are per-thread
+        std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+        auto it = ctx->pad_pipelines.find(pipeline_key);
+        if (it != ctx->pad_pipelines.end()) {
+            pipeline = it->second;
+        } else {
+            ggml_webgpu_processed_shader processed =
+                ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx);
+            pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+            pipeline.context = processed.decisions;
+            ctx->pad_pipelines.emplace(pipeline_key, pipeline);
+        }
+    }
+
+    // Copy the decisions out so the cached pointer is not used past the lock.
+    ggml_webgpu_generic_shader_decisions decisions =
+        *static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context);
+
+    const uint32_t ne = (uint32_t) ggml_nelements(dst);
+
+    std::vector<uint32_t> params = {
+        ne,
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        // Strides (in elements)
+        (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+        // Shapes
+        (uint32_t) src->ne[0],
+        (uint32_t) src->ne[1],
+        (uint32_t) src->ne[2],
+        (uint32_t) src->ne[3],
+        (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1],
+        (uint32_t) dst->ne[2],
+        (uint32_t) dst->ne[3],
+        // Pad sizes (op params 0-7; presumably lo/hi per dimension — confirm
+        // against ggml's PAD op definition)
+        (uint32_t) ggml_get_op_params_i32(dst, 0),
+        (uint32_t) ggml_get_op_params_i32(dst, 1),
+        (uint32_t) ggml_get_op_params_i32(dst, 2),
+        (uint32_t) ggml_get_op_params_i32(dst, 3),
+        (uint32_t) ggml_get_op_params_i32(dst, 4),
+        (uint32_t) ggml_get_op_params_i32(dst, 5),
+        (uint32_t) ggml_get_op_params_i32(dst, 6),
+        (uint32_t) ggml_get_op_params_i32(dst, 7),
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+          .buffer = ggml_webgpu_tensor_buf(src),
+          .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+          .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+        { .binding = 1,
+          .buffer = ggml_webgpu_tensor_buf(dst),
+          .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+          .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
+    };
+
+    uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size);
+    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+}
+
static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
ggml_tensor * src,
ggml_tensor * idx,
return std::nullopt;
}
- webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
- if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
- error_bufs.host_buf.Unmap();
+ ggml_webgpu_set_rows_pipeline_key key = { .dst_type = dst->type,
+ .vec4 = src->ne[0] % 4 == 0,
+ .i64_idx = idx->type == GGML_TYPE_I64 };
+
+ ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = { .key = key,
+ .max_wg_size =
+ ctx->limits.maxComputeInvocationsPerWorkgroup };
+
+ webgpu_pipeline pipeline;
+ // TODO: remove guard once pipeline caches are per-thread
+ {
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+ auto it = ctx->set_rows_pipelines.find(key);
+ if (it != ctx->set_rows_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx);
+ pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ pipeline.context = processed.decisions;
+ ctx->set_rows_pipelines.emplace(key, pipeline);
+ }
+ }
+
+ ggml_webgpu_generic_shader_decisions decisions =
+ *static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context);
+
+ std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
+ if (key.i64_idx) {
+ error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
+ if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
+ error_bufs->host_buf.Unmap();
+ }
}
std::vector<uint32_t> params = {
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
- .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
- { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
- int vectorized = src->ne[0] % 4 == 0;
- webgpu_pipeline pipeline = ctx->set_rows_pipelines[0][vectorized];
- uint32_t threads;
- if (vectorized) {
+ if (key.i64_idx) {
+ entries.push_back(
+ { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
+ }
+
+ uint32_t threads;
+ if (key.vec4) {
threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
} else {
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
}
-
- uint32_t wg_x = CEIL_DIV(threads, WEBGPU_MAX_WG_SIZE);
-
+ uint32_t wg_x = CEIL_DIV(threads, decisions.wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs);
}
bool kv_direct =
(K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
-    flash_attn_pipeline_key key = {
-        .q_type = Q->type,
+    ggml_webgpu_flash_attn_pipeline_key key = {
        .kv_type = K->type,
-        .dst_type = dst->type,
        .head_dim_qk = (uint32_t) Q->ne[0],
        .head_dim_v = (uint32_t) V->ne[0],
        .kv_direct = kv_direct,
+        // NOTE(review): has_mask / has_sinks / uses_logit_softcap are not
+        // initialized here. If they are also absent from any surrounding
+        // lines not shown in this hunk, they default to false and the pipeline
+        // cache key would ignore mask/sinks/softcap even though the shader
+        // variants differ on them — verify against the full function.
    };
webgpu_pipeline pipeline;
- ggml_webgpu_flash_attn_shader_decisions decisions = {};
-
- auto it = ctx->flash_attn_pipelines.find(key);
- if (it != ctx->flash_attn_pipelines.end()) {
- pipeline = it->second;
- decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
- } else {
+ // TODO: remove guard once pipeline caches are per-thread
+ {
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
- it = ctx->flash_attn_pipelines.find(key);
+ auto it = ctx->flash_attn_pipelines.find(key);
if (it != ctx->flash_attn_pipelines.end()) {
pipeline = it->second;
- decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
} else {
- ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type = K->type,
- .head_dim_qk = (uint32_t) Q->ne[0],
- .head_dim_v = (uint32_t) V->ne[0],
- .kv_direct = kv_direct,
- .has_mask = static_cast<bool>(has_mask),
- .has_sinks = static_cast<bool>(has_sinks),
- .uses_logit_softcap = logit_softcap != 0.0f,
- .sg_mat_m = ctx->sg_mat_m,
- .sg_mat_n = ctx->sg_mat_n,
- .sg_mat_k = ctx->sg_mat_k,
+ ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .key = key,
+ .sg_mat_m = ctx->sg_mat_m,
+ .sg_mat_n = ctx->sg_mat_n,
+ .sg_mat_k = ctx->sg_mat_k,
.wg_mem_limit_bytes =
ctx->limits.maxComputeWorkgroupStorageSize,
.max_subgroup_size = ctx->max_subgroup_size };
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
- pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions);
+ pipeline.context = processed.decisions;
ctx->flash_attn_pipelines.emplace(key, pipeline);
- decisions = processed.decisions;
}
}
+ ggml_webgpu_flash_attn_shader_decisions decisions =
+ *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
+
+
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile);
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+ // Encodes element-wise ops: GGML_OP_UNARY sub-ops plus a few non-unary element-wise
+ // ops (CLAMP, FILL, LOG, ...) routed here by ggml_webgpu_encode_node. Pipelines are
+ // specialized per (type, op, is_unary, inplace) and cached in ctx->unary_pipelines.
- uint32_t ne = (uint32_t) ggml_nelements(dst);
- ggml_unary_op unary_op = ggml_get_unary_op(dst);
- uint32_t inplace = ggml_webgpu_tensor_equal(src, dst);
+ bool is_unary = dst->op == GGML_OP_UNARY;
+ // FILL writes a constant into dst and reads no distinct source, so treat it as in-place.
+ bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);
+ int op = is_unary ? (int) ggml_get_unary_op(dst) : dst->op;
- std::vector<uint32_t> params = {
- ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
- (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
- // Convert byte-strides to element-strides
- (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
- (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
- (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
- (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
- // Logical shapes
- (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
- (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
+ ggml_webgpu_unary_pipeline_key pipeline_key = {
+ .type = dst->type, .op = op, .is_unary = is_unary, .inplace = inplace
 };
+ ggml_webgpu_unary_shader_lib_context shader_lib_ctx = { .key = pipeline_key,
+ .max_wg_size =
+ ctx->limits.maxComputeInvocationsPerWorkgroup };
- switch (unary_op) {
- case GGML_UNARY_OP_XIELU:
- {
- // Get float parameters and reinterpret their bit patterns as uint32_t
- // for passing through the params buffer
- float alpha_n = ggml_get_op_params_f32(dst, 1);
- float alpha_p = ggml_get_op_params_f32(dst, 2);
- float beta = ggml_get_op_params_f32(dst, 3);
- float eps = ggml_get_op_params_f32(dst, 4);
- params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n));
- params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p));
- params.push_back(*reinterpret_cast<const uint32_t *>(&beta));
- params.push_back(*reinterpret_cast<const uint32_t *>(&eps));
+ webgpu_pipeline pipeline;
+ {
+ // TODO: remove guard once pipeline caches are per-thread
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+ auto it = ctx->unary_pipelines.find(pipeline_key);
+ if (it != ctx->unary_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx);
+ pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ // NOTE(review): ownership of the decisions pointer moves to the cached pipeline;
+ // presumably it lives until context teardown — confirm there is a matching free.
+ pipeline.context = processed.decisions;
+ ctx->unary_pipelines.emplace(pipeline_key, pipeline);
+ }
+ }
+
+ // Workgroup-size (and related) choices made by the shader preprocessor for this variant.
+ ggml_webgpu_generic_shader_decisions decisions =
+ *static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context);
+
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
+
+ // Offsets are byte misalignments converted to element counts; nb[] byte-strides are
+ // likewise converted to element-strides for the shader.
+ std::vector<uint32_t> params = { ne,
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+ (uint32_t) src->ne[0],
+ (uint32_t) src->ne[1],
+ (uint32_t) src->ne[2] };
+
+ // Op-specific float parameters are appended below; FILL rebinds dst as the "source"
+ // so that binding 0 always refers to a valid buffer.
+ ggml_tensor * effective_src = src;
+ if (is_unary) {
+ ggml_unary_op unary_op = ggml_get_unary_op(dst);
+ switch (unary_op) {
+ case GGML_UNARY_OP_XIELU:
+ {
+ // Get float parameters and reinterpret their bit patterns as uint32_t
+ // for passing through the params buffer
+ float alpha_n = ggml_get_op_params_f32(dst, 1);
+ float alpha_p = ggml_get_op_params_f32(dst, 2);
+ float beta = ggml_get_op_params_f32(dst, 3);
+ float eps = ggml_get_op_params_f32(dst, 4);
+ params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n));
+ params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p));
+ params.push_back(*reinterpret_cast<const uint32_t *>(&beta));
+ params.push_back(*reinterpret_cast<const uint32_t *>(&eps));
+ break;
+ }
+ default:
 break;
- }
- default:
- break;
+ }
+ } else if (dst->op == GGML_OP_CLAMP) {
+ float clamp_min = ggml_get_op_params_f32(dst, 0);
+ float clamp_max = ggml_get_op_params_f32(dst, 1);
+ params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_min));
+ params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_max));
+ } else if (dst->op == GGML_OP_FILL) {
+ float fill_val = ggml_get_op_params_f32(dst, 0);
+ params.push_back(*reinterpret_cast<const uint32_t *>(&fill_val));
+ effective_src = dst; // fill simply fills dst
 }
 std::vector<wgpu::BindGroupEntry> entries = {
 { .binding = 0,
- .buffer = ggml_webgpu_tensor_buf(src),
- .offset = ggml_webgpu_tensor_align_offset(ctx, src),
- .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+ .buffer = ggml_webgpu_tensor_buf(effective_src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, effective_src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, effective_src) },
 };
 if (!inplace) {
 entries.push_back({ .binding = 1,
 .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
 }
- uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
- return ggml_backend_webgpu_build(ctx, ctx->unary_pipelines[unary_op][dst->type][inplace], params, entries, wg_x);
+ // Dispatch width now derives from the preprocessor-chosen wg_size instead of a
+ // global WEBGPU_MAX_WG_SIZE constant.
+ uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size);
+ return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
ggml_nrows(dst));
}
+static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+ // ARGMAX: one workgroup per output element reduces a row of src (length ne[0]).
+ // Params: src/dst element-offset corrections for buffer alignment, plus the row length.
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) src->ne[0] };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
+ };
+
+ // Specialize a vec4 shader variant when the row length is a multiple of 4.
+ ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
+ .vec4 = src->ne[0] % 4 == 0,
+ .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup,
+ };
+
+ webgpu_pipeline pipeline;
+ {
+ // TODO: remove guard once pipeline caches are per-thread
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+ auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4);
+ if (it != ctx->argmax_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax");
+ pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
+ }
+ }
+ // NOTE(review): wg_x is not clamped to maxComputeWorkgroupsPerDimension here, unlike
+ // ggml_webgpu_argsort below — confirm dst element counts cannot exceed that limit.
+ uint32_t wg_x = ggml_nelements(dst);
+ return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+}
+
+static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+ // ARGSORT / TOP_K: multi-pass sort. Pass 0 (argsort pipeline) sorts fixed-size tiles of
+ // each row independently; subsequent passes (argsort_merge pipeline) merge sorted runs
+ // pairwise, ping-ponging between the dst region and a scratch region placed after it in
+ // the same buffer (over-allocated by ggml_backend_webgpu_buffer_type_get_alloc_size).
+ bool is_top_k = dst->op == GGML_OP_TOP_K;
+ // ascending order is 0, descending order is 1
+ const int32_t order = is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(dst, 0);
+
+ ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = { .max_wg_size =
+ ctx->limits.maxComputeInvocationsPerWorkgroup,
+ .wg_mem_limit_bytes =
+ ctx->limits.maxComputeWorkgroupStorageSize,
+ .order = order };
+
+ // Unlike the other encoders, the mutex stays held for the rest of the function: it
+ // guards both pipeline caches (keyed only by sort order) and their shared decisions.
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+ webgpu_pipeline argsort_pipeline;
+ auto it = ctx->argsort_pipelines.find(order);
+ if (it != ctx->argsort_pipelines.end()) {
+ argsort_pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_argsort_shader(ctx->p, wgsl_argsort, shader_lib_ctx);
+ argsort_pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ argsort_pipeline.context = processed.decisions;
+ ctx->argsort_pipelines.emplace(order, argsort_pipeline);
+ }
+ ggml_webgpu_argsort_shader_decisions argsort_decisions =
+ *static_cast<ggml_webgpu_argsort_shader_decisions *>(argsort_pipeline.context);
+
+ webgpu_pipeline argsort_merge_pipeline;
+ it = ctx->argsort_merge_pipelines.find(order);
+ if (it != ctx->argsort_merge_pipelines.end()) {
+ argsort_merge_pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_argsort_merge_shader(ctx->p, wgsl_argsort_merge, shader_lib_ctx);
+ argsort_merge_pipeline =
+ ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ argsort_merge_pipeline.context = processed.decisions;
+ ctx->argsort_merge_pipelines.emplace(order, argsort_merge_pipeline);
+ }
+
+ // npr = sorted tiles per row after pass 0. For TOP_K each tile is truncated to at most
+ // dst->ne[0] candidates before merging, shrinking the working row width out_ne0.
+ const uint32_t src_ne0 = (uint32_t) src->ne[0];
+ const uint32_t nrows = (uint32_t) ggml_nrows(src);
+ const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions.wg_size);
+ const uint32_t block_size =
+ is_top_k ? std::min(argsort_decisions.wg_size, (uint32_t) dst->ne[0]) : argsort_decisions.wg_size;
+ uint32_t out_ne0 = src_ne0;
+ if (is_top_k) {
+ if (npr > 1) {
+ // The final tile may be shorter than a full workgroup tile.
+ const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions.wg_size;
+ out_ne0 = (npr - 1) * block_size + std::min(last_tile, block_size);
+ } else {
+ out_ne0 = block_size;
+ }
+ }
+
+ // Number of pairwise merge passes: runs double in length each pass until one run
+ // covers the whole working row.
+ uint32_t merge_len = block_size;
+ uint32_t merge_passes = 0;
+ while (merge_len < out_ne0) {
+ merge_len <<= 1;
+ merge_passes++;
+ }
+
+ // With an odd number of merge passes, pass 0 must write into the scratch region so the
+ // final merge lands in dst.
+ const bool start_in_tmp = (merge_passes % 2) == 1;
+
+ // Scratch region layout: placed after the real dst indices inside the same allocation,
+ // aligned for use as a separate binding offset (matches get_alloc_size's 2x sizing).
+ const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
+ const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t);
+ const size_t tmp_offset = ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->limits.minStorageBufferOffsetAlignment);
+ const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
+ const size_t dst_binding_size =
+ ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT);
+
+ const uint32_t offset_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type));
+ const uint32_t offset_dst = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type));
+ const uint32_t offset_tmp = 0;
+ const uint32_t stride_src1 = (uint32_t) (src->nb[1] / ggml_type_size(src->type));
+ const uint32_t stride_src2 = (uint32_t) (src->nb[2] / ggml_type_size(src->type));
+ const uint32_t stride_src3 = (uint32_t) (src->nb[3] / ggml_type_size(src->type));
+ // Intermediate index layout is dense with row width out_ne0.
+ const uint32_t stride_idx1 = out_ne0;
+ const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1];
+ const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2];
+
+ // One command per pass; all passes are recorded into a single multi-dispatch build.
+ std::vector<webgpu_pipeline> pipelines;
+ std::vector<std::vector<uint32_t>> params_list;
+ std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
+ std::vector<std::pair<uint32_t, uint32_t>> workgroups_list;
+
+ const uint32_t init_offset = start_in_tmp ? offset_tmp : offset_dst;
+ const size_t init_align_offset = start_in_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
+ const size_t init_binding_size = start_in_tmp ? tmp_binding_size : dst_binding_size;
+
+ std::vector<uint32_t> init_params = {
+ offset_src, init_offset, stride_src1, stride_src2, stride_src3, stride_idx1,
+ stride_idx2, stride_idx3, src_ne0, (uint32_t) src->ne[1], (uint32_t) src->ne[2], out_ne0,
+ block_size, npr, nrows
+ };
+
+ // Fold the (tiles x rows) grid into x/y so neither dimension exceeds the device limit.
+ const uint32_t total_wg_init = npr * nrows;
+ const uint32_t max_wg = ctx->limits.maxComputeWorkgroupsPerDimension;
+ const uint32_t wg_x_init = std::min(total_wg_init, max_wg);
+ const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init);
+ std::vector<wgpu::BindGroupEntry> init_entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+ { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size }
+ };
+
+ pipelines.push_back(argsort_pipeline);
+ params_list.push_back(std::move(init_params));
+ entries_list.push_back(std::move(init_entries));
+ workgroups_list.push_back({ wg_x_init, wg_y_init });
+
+ // Single-tile rows are fully sorted by pass 0; no merging needed.
+ if (merge_passes == 0) {
+ return ggml_backend_webgpu_build_multi(ctx, pipelines, params_list, entries_list, workgroups_list);
+ }
+
+ bool in_is_tmp = start_in_tmp;
+ uint32_t len = block_size;
+ while (len < out_ne0) {
+ // nm = number of merge workgroups per row (one per pair of runs of length `len`).
+ const uint32_t nm = CEIL_DIV(out_ne0, 2 * len);
+
+ const bool out_is_tmp = !in_is_tmp;
+ const uint32_t offset_in = in_is_tmp ? offset_tmp : offset_dst;
+ const uint32_t offset_out = out_is_tmp ? offset_tmp : offset_dst;
+ const size_t align_in = in_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
+ const size_t align_out = out_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
+ const size_t size_in = in_is_tmp ? tmp_binding_size : dst_binding_size;
+ const size_t size_out = out_is_tmp ? tmp_binding_size : dst_binding_size;
+ // On the last merge of a TOP_K, truncate the output row to the requested k.
+ const uint32_t top_k_out = (is_top_k && nm == 1) ? (uint32_t) dst->ne[0] : out_ne0;
+ const uint32_t stride_out1 = top_k_out;
+ const uint32_t stride_out2 = top_k_out * (uint32_t) dst->ne[1];
+ const uint32_t stride_out3 = stride_out2 * (uint32_t) dst->ne[2];
+
+ std::vector<uint32_t> merge_params = { offset_src,
+ offset_in,
+ offset_out,
+ stride_src1,
+ stride_src2,
+ stride_src3,
+ stride_idx1,
+ stride_idx2,
+ stride_idx3,
+ stride_out1,
+ stride_out2,
+ stride_out3,
+ out_ne0,
+ (uint32_t) src->ne[1],
+ (uint32_t) src->ne[2],
+ top_k_out,
+ len,
+ nm,
+ nrows };
+
+ // Bindings: src values (for comparisons), input run buffer, output run buffer.
+ std::vector<wgpu::BindGroupEntry> merge_entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+ { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in },
+ { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out }
+ };
+
+ const uint32_t total_wg_merge = nm * nrows;
+ const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg);
+ const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge);
+ workgroups_list.push_back({ wg_x_merge, wg_y_merge });
+ pipelines.push_back(argsort_merge_pipeline);
+ params_list.push_back(std::move(merge_params));
+ entries_list.push_back(std::move(merge_entries));
+
+ len <<= 1;
+ in_is_tmp = !in_is_tmp;
+ }
+
+ return ggml_backend_webgpu_build_multi(ctx, pipelines, params_list, entries_list, workgroups_list);
+}
+
+static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+ // CUMSUM: cumulative sum along each row; dispatched as one workgroup per row.
+ // Params: src/dst element-offset corrections plus the row length.
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) src->ne[0] };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
+ };
+
+ ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
+ .vec4 = false,
+ .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup,
+ };
+ webgpu_pipeline pipeline;
+ // TODO: remove guard once pipeline caches are per-thread
+ {
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+ // Single shader variant: the cache is keyed by the constant 1.
+ auto it = ctx->cumsum_pipelines.find(1);
+ if (it != ctx->cumsum_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum");
+ pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ ctx->cumsum_pipelines.emplace(1, pipeline);
+ }
+ }
+ uint32_t wg_x = ggml_nrows(dst);
+ return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+}
+
+static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+ // Shared encoder for GGML_OP_SUM_ROWS (one sum per row) and GGML_OP_SUM (sum of the
+ // whole tensor). For SUM the tensor is presented to the shader as one contiguous row
+ // of ggml_nelements(src) elements, with strides zeroed and ne1/ne2 forced to 1.
+ bool total_sum = dst->op == GGML_OP_SUM;
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ total_sum ? 0 : (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+ total_sum ? 0 : (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+ total_sum ? 0 : (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+ total_sum ? static_cast<uint32_t>(ggml_nelements(src)) : (uint32_t) src->ne[0],
+ total_sum ? 1 : (uint32_t) src->ne[1],
+ total_sum ? 1 : (uint32_t) src->ne[2] };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
+ };
+
+ ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
+ .vec4 = false,
+ .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup,
+ };
+
+ webgpu_pipeline pipeline;
+ {
+ // TODO: remove guard once pipeline caches are per-thread
+ std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+ // Single shader variant: the cache is keyed by the constant 1.
+ auto it = ctx->sum_rows_pipelines.find(1);
+ if (it != ctx->sum_rows_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows");
+ pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ ctx->sum_rows_pipelines.emplace(1, pipeline);
+ }
+ }
+ // One workgroup per output row; a total SUM collapses to a single workgroup.
+ uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
+ return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+}
+
// Returns the encoded command, or std::nullopt if the operation is a no-op
static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
 if (ggml_is_empty(node)) {
 return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
 case GGML_OP_UNARY:
 return ggml_webgpu_unary_op(ctx, src0, node);
+ // CLAMP/FILL/LOG are element-wise and reuse the generic unary encoder.
+ case GGML_OP_CLAMP:
+ return ggml_webgpu_unary_op(ctx, src0, node);
+ case GGML_OP_FILL:
+ return ggml_webgpu_unary_op(ctx, src0, node);
+ case GGML_OP_LOG:
+ return ggml_webgpu_unary_op(ctx, src0, node);
+ case GGML_OP_PAD:
+ return ggml_webgpu_pad(ctx, src0, node);
+ case GGML_OP_ARGMAX:
+ return ggml_webgpu_argmax(ctx, src0, node);
+ case GGML_OP_ARGSORT:
+ return ggml_webgpu_argsort(ctx, src0, node);
+ case GGML_OP_TOP_K:
+ // we reuse the same argsort implementation for top_k
+ return ggml_webgpu_argsort(ctx, src0, node);
+ case GGML_OP_CUMSUM:
+ return ggml_webgpu_cumsum(ctx, src0, node);
+ // SUM and SUM_ROWS share one encoder; it distinguishes them via dst->op.
+ case GGML_OP_SUM:
+ case GGML_OP_SUM_ROWS:
+ return ggml_webgpu_sum_rows(ctx, src0, node);
 default:
 return std::nullopt;
 }
return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize;
}
+static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
+ const ggml_tensor * tensor) {
+ // Over-allocates for ops that need scratch space inside the destination buffer.
+ // ARGSORT/TOP_K reserve roughly 2x the index storage so ggml_webgpu_argsort can
+ // ping-pong merge passes between dst and a scratch region after it, plus alignment
+ // slack for the scratch binding's offset.
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
+ size_t res = ggml_nbytes(tensor);
+ switch (tensor->op) {
+ case GGML_OP_ARGSORT:
+ res = ROUNDUP_POW2(res * 2 + ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment,
+ WEBGPU_STORAGE_BUF_BINDING_MULT);
+ break;
+ case GGML_OP_TOP_K:
+ {
+ // TOP_K scratch is sized by the full source row (i32 indices per element),
+ // since the working row width before truncation can exceed dst's size.
+ const ggml_tensor * src0 = tensor->src[0];
+ if (src0) {
+ const size_t full = sizeof(int32_t) * ggml_nelements(src0);
+ res = ROUNDUP_POW2(full * 2 + ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment,
+ WEBGPU_STORAGE_BUF_BINDING_MULT);
+ }
+ }
+ break;
+ default:
+ break;
+ }
+ return res;
+}
+
/* End GGML Backend Buffer Type Interface */
/* GGML Backend Device Interface */
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
// TODO: for now, return maxBufferSize as both free and total memory
// Track https://github.com/gpuweb/gpuweb/issues/5505 for updates.
- uint64_t max_buffer_size = ctx->webgpu_ctx->limits.maxBufferSize;
+ uint64_t max_buffer_size = ctx->webgpu_ctx->limits.maxBufferSize;
// If we're on a 32-bit system, clamp to UINTPTR_MAX
#if UINTPTR_MAX < UINT64_MAX
uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX);
webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
}
-static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
- webgpu_ctx->set_rows_pipelines[0][0] = ggml_webgpu_create_pipeline(
- webgpu_ctx->device, wgsl_set_rows_f16, "set_rows_f16", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE));
- webgpu_ctx->set_rows_pipelines[0][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->device, wgsl_set_rows_f16_vec, "set_rows_f16_vec", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE));
-}
-
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
}
-static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
- // ABS
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f32, "abs_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f16, "abs_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f32, "abs_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f16, "abs_inplace_f16", constants);
-
- // SGN
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f32, "sgn_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f16, "sgn_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f32, "sgn_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f16, "sgn_inplace_f16", constants);
-
- // NEG
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f32, "neg_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f16, "neg_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f32, "neg_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f16, "neg_inplace_f16", constants);
-
- // STEP
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f32, "step_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f16, "step_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f32, "step_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f16, "step_inplace_f16", constants);
-
- // TANH
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f32, "tanh_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f16, "tanh_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f32, "tanh_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f16, "tanh_inplace_f16", constants);
-
- // ELU
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f32, "elu_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f16, "elu_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f32, "elu_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f16, "elu_inplace_f16", constants);
-
- // RELU
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f32, "relu_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f16, "relu_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f32, "relu_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f16, "relu_inplace_f16", constants);
-
- // SIGMOID
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f32, "sigmoid_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f16, "sigmoid_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f32, "sigmoid_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f16, "sigmoid_inplace_f16", constants);
-
- // GELU
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f32, "gelu_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f16, "gelu_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f32, "gelu_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f16, "gelu_inplace_f16", constants);
-
- // GELU_QUICK
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f32, "gelu_quick_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f16, "gelu_quick_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->device, wgsl_gelu_quick_inplace_f32, "gelu_quick_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->device, wgsl_gelu_quick_inplace_f16, "gelu_quick_inplace_f16", constants);
-
- // SILU
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f32, "silu_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f16, "silu_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f32, "silu_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f16, "silu_inplace_f16", constants);
-
- // HARDSWISH
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f32, "hardswish_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f16, "hardswish_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f32, "hardswish_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f16, "hardswish_inplace_f16", constants);
-
- // HARDSIGMOID
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->device, wgsl_hardsigmoid_inplace_f32, "hardsigmoid_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->device, wgsl_hardsigmoid_inplace_f16, "hardsigmoid_inplace_f16", constants);
-
- // EXP
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f32, "exp_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f16, "exp_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f32, "exp_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f16, "exp_inplace_f16", constants);
-
- // GELU_ERF
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f32, "gelu_erf_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f16, "gelu_erf_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f32, "gelu_erf_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f16, "gelu_erf_inplace_f16", constants);
-
- // XIELU
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f32, "xielu_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f16, "xielu_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f32, "xielu_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f16, "xielu_inplace_f16", constants);
-
- // CEIL
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f32, "ceil_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f16, "ceil_f16", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f32, "ceil_inplace_f32", constants);
- webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f16, "ceil_inplace_f16", constants);
-}
-
static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
/* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
/* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
- /* .is_host = */ NULL, // defaults to false
+ /* .get_alloc_size = */ ggml_backend_webgpu_buffer_type_get_alloc_size,
+ /* .is_host = */ NULL, // defaults to false
},
/* .device = */
dev,
break;
case GGML_OP_CPY:
case GGML_OP_CONT:
- supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
- (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+ supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
+ (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
break;
case GGML_OP_SET_ROWS:
- supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64);
+ supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
+ (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
break;
case GGML_OP_GET_ROWS:
- if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
- ggml_webgpu_supported_qtype(src0->type)) {
+ if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) {
supports_op = (op->type == GGML_TYPE_F32);
+ } else if (src0->type == GGML_TYPE_I32) {
+ supports_op = op->type == GGML_TYPE_I32;
}
break;
case GGML_OP_MUL_MAT:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_GELU_ERF:
- case GGML_UNARY_OP_XIELU:
+ case GGML_UNARY_OP_SOFTPLUS:
+ case GGML_UNARY_OP_EXPM1:
+ case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
- supports_op = supports_op =
+ case GGML_UNARY_OP_ROUND:
+ case GGML_UNARY_OP_TRUNC:
+ case GGML_UNARY_OP_XIELU:
+ supports_op =
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
default:
}
}
break;
-
+ case GGML_OP_CLAMP:
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+ break;
+ case GGML_OP_FILL:
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
+ break;
+ case GGML_OP_LOG:
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+ break;
+ case GGML_OP_PAD:
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
+ break;
+ case GGML_OP_ARGMAX:
+ supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32;
+ break;
+ case GGML_OP_ARGSORT:
+ supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
+ break;
+ case GGML_OP_TOP_K:
+ supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
+ break;
+ case GGML_OP_CUMSUM:
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type;
+ break;
+ case GGML_OP_SUM:
+ case GGML_OP_SUM_ROWS:
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0);
+ break;
default:
break;
}
ggml_webgpu_init_memset_pipeline(ctx);
ggml_webgpu_init_mul_mat_pipeline(ctx);
- ggml_webgpu_init_set_rows_pipeline(ctx);
ggml_webgpu_init_get_rows_pipeline(ctx);
ggml_webgpu_init_cpy_pipeline(ctx);
ggml_webgpu_init_add_pipeline(ctx);
ggml_webgpu_init_glu_pipeline(ctx);
ggml_webgpu_init_scale_pipeline(ctx);
ggml_webgpu_init_soft_max_pipeline(ctx);
- ggml_webgpu_init_unary_pipeline(ctx);
#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers
+++ /dev/null
-#define(REPL_TEMPLATES)
-
-{
- "XIELU_FUNC": "{{MUTATE}}[dst_i] = select(((exp(min(src[src_i], {{TYPE}}(params.eps))) - 1.0) - src[src_i]) * {{TYPE}}(params.alpha_n) + {{TYPE}}(params.beta) * src[src_i], {{TYPE}}(params.alpha_p) * src[src_i] * src[src_i] + {{TYPE}}(params.beta) * src[src_i], src[src_i] > 0.0);",
- "ABS_FUNC": "{{MUTATE}}[dst_i] = abs(src[src_i]);",
- "SGN_FUNC": "{{MUTATE}}[dst_i] = select({{TYPE}}(select(0.0, -1.0, src[src_i] < 0.0)), {{TYPE}}(1.0), src[src_i] > 0.0);",
- "NEG_FUNC": "{{MUTATE}}[dst_i] = -src[src_i];",
- "STEP_FUNC": "{{MUTATE}}[dst_i] = {{TYPE}}(select(0.0, 1.0, src[src_i] > 0.0));",
- "TANH_FUNC": "{{MUTATE}}[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
- "RELU_FUNC": "{{MUTATE}}[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);",
- "ELU_FUNC": "{{MUTATE}}[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);",
- "HARDSIGMOID_FUNC": "{{MUTATE}}[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
- "SIGMOID_FUNC": "{{MUTATE}}[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));",
- "SILU_FUNC": "{{MUTATE}}[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));",
- "EXP_FUNC": "{{MUTATE}}[dst_i] = exp(src[src_i]);",
- "HARDSWISH_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
- "GELU_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
- "GELU_QUICK_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
- "GELU_ERF_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
- "CEIL_FUNC": "{{MUTATE}}[dst_i] = ceil(src[src_i]);"
-}
-
-#end(REPL_TEMPLATES)
-
-#define(VARIANTS)
-
-[
- {
- "SHADER_NAME": "abs_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "abs_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "abs_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "abs_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
-
- {
- "SHADER_NAME": "sgn_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "sgn_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "sgn_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "sgn_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
-
- {
- "SHADER_NAME": "neg_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "neg_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "neg_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "neg_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
-
- {
- "SHADER_NAME": "step_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "step_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "step_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "step_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
-
- {
- "SHADER_NAME": "tanh_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "tanh_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "tanh_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "tanh_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
-
- {
- "SHADER_NAME": "elu_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "elu_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "elu_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "elu_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
-
- {
- "SHADER_NAME": "relu_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "relu_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "relu_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "relu_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
-
- {
- "SHADER_NAME": "sigmoid_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "sigmoid_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "sigmoid_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "sigmoid_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
-
- {
- "SHADER_NAME": "silu_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "silu_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "silu_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "silu_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
-
- {
- "SHADER_NAME": "exp_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "exp_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "exp_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "exp_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
-
- {
- "SHADER_NAME": "hardsigmoid_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "hardsigmoid_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "hardsigmoid_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "hardsigmoid_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
-
- {
- "SHADER_NAME": "hardswish_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "hardswish_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "hardswish_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "hardswish_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
-
- {
- "SHADER_NAME": "gelu_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "gelu_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "gelu_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "gelu_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
-
- {
- "SHADER_NAME": "gelu_quick_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "gelu_quick_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "gelu_quick_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "gelu_quick_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
-
- {
- "SHADER_NAME": "xielu_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "xielu_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "xielu_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "xielu_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "gelu_erf_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "gelu_erf_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "gelu_erf_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "gelu_erf_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
-
- {
- "SHADER_NAME": "ceil_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "ceil_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "ceil_inplace_f32",
- "REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "ceil_inplace_f16",
- "REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
- "DECLS": ["INPLACE"]
- }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(INPLACE)
-
-@group(0) @binding(1)
-var<uniform> params: Params;
-
-#enddecl(INPLACE)
-
-#decl(NOT_INPLACE)
-
-@group(0) @binding(1)
-var<storage, read_write> dst: array<{{TYPE}}>;
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-#enddecl(NOT_INPLACE)
-
-#end(DECLS)
-
-#define(SHADER)
-
-enable f16;
-
-fn update(dst_i: u32, src_i: u32) {
- {{FUNC}}
-}
-
-@group(0) @binding(0)
-var<storage, read_write> src: array<{{TYPE}}>;
-
-DECLS
-
-struct Params {
- ne: u32, // total number of elements
- offset_src: u32, // in elements
- offset_dst: u32, // in elements
-
- // Strides (in elements) — may be permuted
- stride_src0: u32,
- stride_src1: u32,
- stride_src2: u32,
- stride_src3: u32,
-
- stride_dst0: u32,
- stride_dst1: u32,
- stride_dst2: u32,
- stride_dst3: u32,
-
- // Logical shapes
- src_ne0: u32,
- src_ne1: u32,
- src_ne2: u32,
-
- dst_ne0: u32,
- dst_ne1: u32,
- dst_ne2: u32,
-
- {{EXT_PARAMS}}
-};
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
- if (gid.x >= params.ne) {
- return;
- }
-
- var i = gid.x;
- let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
- i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
- let i2 = i / (params.src_ne1 * params.src_ne0);
- i = i % (params.src_ne1 * params.src_ne0);
- let i1 = i / params.src_ne0;
- let i0 = i % params.src_ne0;
-
- var j = gid.x;
- let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
- j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
- let j2 = j / (params.dst_ne1 * params.dst_ne0);
- j = j % (params.dst_ne1 * params.dst_ne0);
- let j1 = j / params.dst_ne0;
- let j0 = j % params.dst_ne0;
-
- let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
- i2 * params.stride_src2 + i3 * params.stride_src3;
-
- let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
- j2 * params.stride_dst2 + j3 * params.stride_dst3;
-
-
- update(params.offset_dst + dst_idx, params.offset_src + src_idx);
-}
-
-#end(SHADER)
-