From: Reese Levine Date: Wed, 25 Feb 2026 07:33:32 +0000 (+0200) Subject: ggml webgpu: shader library organization (llama/19530) X-Git-Tag: upstream/1.8.4~127 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=fc7a78f4d8662f8d970629622f8a4f15bd90719b;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp ggml webgpu: shader library organization (llama/19530) * Basic JIT compilation for mul_mat, get_rows, and scale (ggml/17) * scale jit working * preliminary working jit for getrows and mulmat, needs refining * simplified mul_mat preprocessing switch statement * get_rows fixes, mul_mat refinement * formatted + last edits * removed some extraneous prints * fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish * small fix * some changes, working * get_rows and mul_mat jit fixed and working * Update formatting * formatting * Add header --------- Co-authored-by: Neha Abbas Co-authored-by: Reese Levine * Start work on all-encompassing shader library * refactor argmax, set_rows * Refactor all but flashattention, mat mul * flashattention and matrix multiplication moved to new format * clean up preprocessing * Formatting * remove duplicate constants * Split large shaders into multiple static strings --------- Co-authored-by: neha-ha --- diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 63f797f1..0d5a818d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1,11 +1,16 @@ #ifndef GGML_WEBGPU_SHADER_LIB_HPP #define GGML_WEBGPU_SHADER_LIB_HPP +#include "ggml-wgsl-shaders.hpp" #include "ggml.h" #include "pre_wgsl.hpp" +#include + +#include #include #include +#include #include #define GGML_WEBGPU_F16_SIZE_BYTES 2 @@ -18,17 +23,203 @@ #define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u -struct ggml_webgpu_processed_shader { - std::string wgsl; - std::string variant; - std::shared_ptr decisions; -}; +// Matrix multiplication parameters + +// Register tiling parameters +#define WEBGPU_MUL_MAT_TILE_M 8 +#define WEBGPU_MUL_MAT_TILE_N 8 +#define WEBGPU_MUL_MAT_WG_SIZE_M 8 +#define WEBGPU_MUL_MAT_WG_SIZE_N 8 +#define WEBGPU_MUL_MAT_TILE_K 32 + +// Subgroup matrix parameters +// The number of subgroups in the M dimension +#define WEBGPU_MUL_MAT_SUBGROUP_M 2 +// The number of subgroups in the N dimension +#define WEBGPU_MUL_MAT_SUBGROUP_N 2 +// The number of subgroup matrices each subgroup accumulates over +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 + +// Matrix-vector multiplication parameters +#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 +// Must be multiple of 4 to work with vectorized paths, and must divide +// mul_mat_vec wg size +#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 +#define WEBGPU_MUL_MAT_VEC_TILE_K 256 + +// default size for legacy matrix multiplication +#define WEBGPU_MUL_MAT_WG_SIZE 256 // Same hash combine function as in boost template inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) { seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); } +struct ggml_webgpu_shader_lib_context { + ggml_tensor * src0; + ggml_tensor * src1; + ggml_tensor * src2; + ggml_tensor * src3; + ggml_tensor * src4; + ggml_tensor * dst; + + uint32_t max_wg_size; + size_t wg_mem_limit_bytes = 0; + bool inplace = false; + bool overlap = false; + bool supports_subgroup_matrix = false; + uint32_t sg_mat_m = 0; + uint32_t sg_mat_n = 0; + uint32_t sg_mat_k = 0; + uint32_t max_subgroup_size = 0; +}; + +struct webgpu_pipeline { + wgpu::ComputePipeline pipeline; + std::string name; + std::shared_ptr context = nullptr; +}; + +struct ggml_webgpu_generic_shader_decisions { + uint32_t wg_size = 0; +}; + +/** Argsort **/ + +struct ggml_webgpu_argsort_shader_lib_context { + uint32_t max_wg_size; + size_t wg_mem_limit_bytes; + int32_t order; +}; + +/** Set Rows **/ + +struct ggml_webgpu_set_rows_pipeline_key { + int dst_type; + int vec4; + int i64_idx; + + bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const { + return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx; + } +}; + +struct ggml_webgpu_set_rows_pipeline_key_hash { + size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.dst_type); + ggml_webgpu_hash_combine(seed, key.vec4); + ggml_webgpu_hash_combine(seed, key.i64_idx); + return seed; + } +}; + +struct ggml_webgpu_set_rows_shader_decisions { + bool vec4; + bool i64_idx; + uint32_t wg_size; +}; + +/** Get Rows **/ + +struct ggml_webgpu_get_rows_pipeline_key { + ggml_type src_type; + int vectorized; + + bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const { + return src_type == other.src_type && vectorized == other.vectorized; + } +}; + +struct ggml_webgpu_get_rows_pipeline_key_hash { + size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; + +/** Pad **/ +struct ggml_webgpu_pad_pipeline_key { + bool circular; + + bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; } +}; + +struct ggml_webgpu_pad_pipeline_key_hash { + size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.circular); + return seed; + } +}; + +/** Scale **/ + +struct ggml_webgpu_scale_pipeline_key { + int inplace; + + bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; } +}; + +struct ggml_webgpu_scale_pipeline_key_hash { + size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + +/** Binary **/ + +struct ggml_webgpu_binary_pipeline_key { + int type; + int op; + bool inplace; + bool overlap; + + bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { + return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap; + } +}; + +struct ggml_webgpu_binary_pipeline_key_hash { + size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.overlap); + return seed; + } +}; + +/** Unary **/ + +struct ggml_webgpu_unary_pipeline_key { + int type; + int op; + bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella + bool inplace; + + bool operator==(const ggml_webgpu_unary_pipeline_key & other) const { + return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace; + } +}; + +struct ggml_webgpu_unary_pipeline_key_hash { + size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.is_unary); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + /** FlashAttention */ struct ggml_webgpu_flash_attn_pipeline_key { @@ -100,439 +291,941 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } -static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - const size_t q_tile = context.sg_mat_m; - const size_t base_q_bytes = - (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!context.key.kv_direct) { - bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); - } - if (context.key.has_mask) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; -} - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_shader_lib_context & context) { - std::vector defines; - std::string variant = "flash_attn"; - - switch (context.key.kv_type) { - case GGML_TYPE_F32: - defines.push_back("KV_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("KV_F16"); - break; - case GGML_TYPE_Q4_0: - defines.push_back("KV_Q4_0"); - break; - case GGML_TYPE_Q8_0: - defines.push_back("KV_Q8_0"); - break; - default: - GGML_ABORT("Unsupported KV type for flash attention shader"); - } - variant += std::string("_") + ggml_type_name(context.key.kv_type); - - if (context.key.has_mask) { - defines.push_back("MASK"); - variant += "_mask"; - } - if (context.key.has_sinks) { - defines.push_back("SINKS"); - variant += "_sinks"; - } - if (context.key.uses_logit_softcap) { - defines.push_back("LOGIT_SOFTCAP"); - variant += "_lgsc"; - } - - if (context.key.kv_direct) { - defines.push_back("KV_DIRECT"); - variant += "_kvdirect"; - } - - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk); - - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); - // For now these are not part of the variant name - defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); - defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); - defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - - // Add chosen Q/KV tile sizes - uint32_t q_tile = context.sg_mat_m; - uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (context.key.kv_direct) { - GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - // Avoids having to use bounds-checks and decreasing performance for direct KV loads - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= context.sg_mat_n; - } - } - - defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); - - // workgroup size - uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; - result.decisions = decisions; - return result; -} +/** Matrix Multiplication **/ -/** Generic **/ +struct ggml_webgpu_legacy_mul_mat_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; -struct ggml_webgpu_generic_shader_lib_context { - int vec4; - uint32_t max_wg_size; + bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type; + } }; -struct ggml_webgpu_generic_shader_decisions { - uint32_t wg_size; +struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + return seed; + } }; -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_generic_shader_lib_context & context, - const std::string & base_variant) { - std::vector defines; - std::string variant = base_variant; +struct ggml_webgpu_mul_mat_vec_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + int vectorized; - if (context.vec4) { - defines.push_back("VEC4"); - variant += "_vec"; + bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized; } +}; - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - return result; -} +struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { + size_t operator()(const ggml_webgpu_mul_mat_vec_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; -/** Pad **/ +struct ggml_webgpu_mul_mat_vec_shader_decisions { + uint32_t wg_size; + uint32_t tile_k; + uint32_t outputs_per_wg; + uint32_t vec_size; +}; -struct ggml_webgpu_pad_pipeline_key { - bool circular; +struct ggml_webgpu_mul_mat_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + int vectorized; + int use_subgroup_matrix; - bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; } + bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && + use_subgroup_matrix == other.use_subgroup_matrix; + } }; -struct ggml_webgpu_pad_pipeline_key_hash { - size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const { +struct ggml_webgpu_mul_mat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const { size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.circular); + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix); return seed; } }; -struct ggml_webgpu_pad_shader_lib_context { - ggml_webgpu_pad_pipeline_key key; - uint32_t max_wg_size; +struct ggml_webgpu_mul_mat_shader_decisions { + uint32_t tile_k; + uint32_t wg_size_m; + uint32_t wg_size_n; + uint32_t wg_size; + uint32_t outputs_per_wg; + int use_subgroup_matrix; + + uint32_t tile_m; + uint32_t tile_n; + + // Subgroup matrix parameters + uint32_t subgroup_m; + uint32_t subgroup_n; + uint32_t subgroup_matrix_m; + uint32_t subgroup_matrix_n; + + uint32_t mul_mat_wg_size; }; -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_pad_shader_lib_context & context) { - std::vector defines; - std::string variant = "pad"; +class ggml_webgpu_shader_lib { + wgpu::Device device; + pre_wgsl::Preprocessor preprocessor; + + std::unordered_map sum_rows_pipelines; // key is fixed, no variants yet + std::unordered_map argmax_pipelines; // key is vec4 + std::unordered_map argsort_pipelines; // key is order + std::unordered_map argsort_merge_pipelines; // key is order + std::unordered_map cumsum_pipelines; // key is fixed, no variants yet + std::unordered_map + get_rows_pipelines; // src_type, vectorized + std::unordered_map + unary_pipelines; // type/op/inplace + std::unordered_map + scale_pipelines; // inplace + std::unordered_map + pad_pipelines; // circular/non-circular + std::unordered_map + binary_pipelines; // type/op/inplace/overlap + std::unordered_map + flash_attn_pipelines; + std::unordered_map + mul_mat_legacy_pipelines; // legacy mul_mat (non-subgroup/non-regtile/non-vec) + std::unordered_map + mul_mat_vec_pipelines; // fast mat-vec (n==1) + std::unordered_map + mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + + std::unordered_map + set_rows_pipelines; + + public: + ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } + + webgpu_pipeline get_sum_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { + auto it = sum_rows_pipelines.find(1); + if (it != sum_rows_pipelines.end()) { + return it->second; + } + std::vector defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - if (context.key.circular) { - defines.push_back("CIRCULAR"); - variant += "_circular"; + auto processed = preprocessor.preprocess(wgsl_sum_rows, defines); + sum_rows_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "sum_rows"); + return sum_rows_pipelines[1]; } - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) { + bool vec4 = context.src0->ne[0] % 4 == 0; - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; - return result; -} + auto it = argmax_pipelines.find(vec4); + if (it != argmax_pipelines.end()) { + return it->second; + } + std::string variant = "argmax"; + std::vector defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + if (vec4) { + defines.push_back("VEC4"); + variant += "_vec4"; + } -/** Argsort **/ + auto processed = preprocessor.preprocess(wgsl_argmax, defines); + argmax_pipelines[vec4] = ggml_webgpu_create_pipeline(device, processed, variant); + return argmax_pipelines.at(vec4); + } -struct ggml_webgpu_argsort_shader_lib_context { - uint32_t max_wg_size; - size_t wg_mem_limit_bytes; - int32_t order; -}; + webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type, + .vec4 = context.src0->ne[0] % 4 == 0, + .i64_idx = context.src1->type == GGML_TYPE_I64 }; -struct ggml_webgpu_argsort_shader_decisions { - uint32_t wg_size = 0; -}; + auto it = set_rows_pipelines.find(key); + if (it != set_rows_pipelines.end()) { + return it->second; + } -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_argsort_shader_lib_context & context) { - std::vector defines; - std::string variant = "argsort"; - defines.push_back(std::string("ORDER=") + std::to_string(context.order)); - variant += std::string("_order") + std::to_string(context.order); - uint32_t wg_size = 1; - while (wg_size * 2 <= context.max_wg_size && - wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) { - wg_size *= 2; - } - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->wg_size = wg_size; - result.decisions = decisions; - return result; -} + std::vector defines; + std::string variant = "set_rows"; + + switch (context.dst->type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + variant += "_dstf32"; + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + variant += "_dstf16"; + break; + default: + GGML_ABORT("Unsupported dst type for set_rows shader"); + } -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_argsort_shader_lib_context & context) { - std::vector defines; - std::string variant = "argsort_merge"; - defines.push_back(std::string("ORDER=") + std::to_string(context.order)); - variant += std::string("_order") + std::to_string(context.order); - uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size); - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->wg_size = wg_size; - result.decisions = decisions; - return result; -} + if (key.vec4) { + defines.push_back("VEC4"); + variant += "_vec4"; + } + if (key.i64_idx) { + defines.push_back("I64_IDX"); + variant += "_i64idx"; + } -/** Set Rows **/ + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); -struct ggml_webgpu_set_rows_pipeline_key { - int dst_type; - int vec4; - int i64_idx; + auto processed = preprocessor.preprocess(wgsl_set_rows, defines); + auto decisions = std::make_shared(); + decisions->vec4 = key.vec4; + decisions->i64_idx = key.i64_idx; + decisions->wg_size = context.max_wg_size; + set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); + set_rows_pipelines[key].context = decisions; + return set_rows_pipelines[key]; + } - bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const { - return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx; + webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) { + auto it = cumsum_pipelines.find(1); + if (it != cumsum_pipelines.end()) { + return it->second; + } + + std::vector defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_cumsum, defines); + cumsum_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "cumsum"); + return cumsum_pipelines[1]; } -}; -struct ggml_webgpu_set_rows_pipeline_key_hash { - size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.dst_type); - ggml_webgpu_hash_combine(seed, key.vec4); - ggml_webgpu_hash_combine(seed, key.i64_idx); - return seed; + webgpu_pipeline get_argsort_pipeline(const ggml_webgpu_shader_lib_context & context) { + bool is_top_k = context.dst->op == GGML_OP_TOP_K; + // ascending order is 0, descending order is 1 + const int32_t order = + is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0); + + auto it = argsort_pipelines.find(order); + if (it != argsort_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "argsort"; + defines.push_back(std::string("ORDER=") + std::to_string(order)); + variant += std::string("_order") + std::to_string(order); + uint32_t wg_size = 1; + while (wg_size * 2 <= context.max_wg_size && + wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) { + wg_size *= 2; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + auto processed = preprocessor.preprocess(wgsl_argsort, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + argsort_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant); + argsort_pipelines[order].context = decisions; + return argsort_pipelines[order]; } -}; -struct ggml_webgpu_set_rows_shader_lib_context { - ggml_webgpu_set_rows_pipeline_key key; - uint32_t max_wg_size; -}; + webgpu_pipeline get_argsort_merge_pipeline(const ggml_webgpu_shader_lib_context & context) { + bool is_top_k = context.dst->op == GGML_OP_TOP_K; + // ascending order is 0, descending order is 1 + const int32_t order = + is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0); -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_set_rows_shader_lib_context & context) { - std::vector defines; - std::string variant = "set_rows"; - - switch (context.key.dst_type) { - case GGML_TYPE_F32: - defines.push_back("DST_F32"); - variant += "_dstf32"; - break; - case GGML_TYPE_F16: - defines.push_back("DST_F16"); - variant += "_dstf16"; - break; - default: - GGML_ABORT("Unsupported dst type for set_rows shader"); - } - - if (context.key.vec4) { - defines.push_back("VEC4"); - variant += "_vec"; - } - if (context.key.i64_idx) { - defines.push_back("I64_IDX"); - variant += "_i64idx"; - } - - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; - return result; -} + auto it = argsort_merge_pipelines.find(order); + if (it != argsort_merge_pipelines.end()) { + return it->second; + } -struct ggml_webgpu_unary_pipeline_key { - int type; - int op; - bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella - bool inplace; + std::vector defines; + std::string variant = "argsort_merge"; + defines.push_back(std::string("ORDER=") + std::to_string(order)); + variant += std::string("_order") + std::to_string(order); + uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - bool operator==(const ggml_webgpu_unary_pipeline_key & other) const { - return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace; + auto processed = preprocessor.preprocess(wgsl_argsort_merge, defines); + argsort_merge_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant); + return argsort_merge_pipelines[order]; } -}; -struct ggml_webgpu_unary_pipeline_key_hash { - size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.type); - ggml_webgpu_hash_combine(seed, key.op); - ggml_webgpu_hash_combine(seed, key.is_unary); - ggml_webgpu_hash_combine(seed, key.inplace); - return seed; + webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0; + ggml_webgpu_get_rows_pipeline_key key = { + .src_type = context.src0->type, + .vectorized = (int) vectorized, + }; + + auto it = get_rows_pipelines.find(key); + if (it != get_rows_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "get_rows"; + + const struct ggml_type_traits * type_traits = ggml_get_type_traits(key.src_type); + const char * type_str = type_traits->type_name; + + switch (key.src_type) { + case GGML_TYPE_F32: + if (key.vectorized) { + defines.push_back("F32_VEC"); + defines.push_back("SRC_TYPE=vec4"); + defines.push_back("DST_TYPE=vec4"); + defines.push_back("BLOCK_SIZE=4u"); + } else { + defines.push_back("F32"); + defines.push_back("SRC_TYPE=f32"); + defines.push_back("DST_TYPE=f32"); + defines.push_back("BLOCK_SIZE=1u"); + } + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("F16"); + defines.push_back("SRC_TYPE=f16"); + defines.push_back("DST_TYPE=f32"); + defines.push_back("BLOCK_SIZE=1u"); + variant += "_f16"; + break; + case GGML_TYPE_I32: + defines.push_back("I32"); + defines.push_back("SRC_TYPE=i32"); + defines.push_back("DST_TYPE=i32"); + defines.push_back("BLOCK_SIZE=1u"); + variant += "_i32"; + break; + default: + { + std::string type_upper = type_str; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back(type_upper + "_T"); + defines.push_back(type_upper); + defines.push_back(type_upper + "_SCALE_MIN"); + defines.push_back(type_upper + "_TABLES"); + defines.push_back(type_upper + "_GRID"); + + variant += "_"; + variant += type_str; + + defines.push_back(std::string("SRC_TYPE=") + type_str); + defines.push_back("DST_TYPE=f32"); + + if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || + key.src_type == GGML_TYPE_IQ4_NL) { + defines.push_back("BLOCK_SIZE=32u"); + } else if (key.src_type >= GGML_TYPE_Q2_K) { + defines.push_back("BLOCK_SIZE=256u"); + } else { + defines.push_back("BLOCK_SIZE=1u"); + } + break; + } + } + + if (key.vectorized) { + variant += "_vec"; + } + + defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_get_rows, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + get_rows_pipelines[key] = pipeline; + return get_rows_pipelines[key]; } -}; -struct ggml_webgpu_unary_shader_lib_context { - ggml_webgpu_unary_pipeline_key key; - uint32_t max_wg_size; -}; + webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace }; -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_unary_shader_lib_context & context) { - std::vector defines; - std::string variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) : - ggml_op_name((ggml_op) context.key.op); - // Operation-specific behavior - defines.push_back(variant); - - switch (context.key.type) { - case GGML_TYPE_F32: - defines.push_back("TYPE_F32"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("TYPE_F16"); - variant += "_f16"; - break; - default: - GGML_ABORT("Unsupported type for unary shader"); - } - - if (context.key.inplace) { - defines.push_back("INPLACE"); - variant += "_inplace"; - } - - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; - return result; -} + auto it = scale_pipelines.find(key); + if (it != scale_pipelines.end()) { + return it->second; + } -/** Binary **/ + std::vector defines; + std::string variant = "scale"; -struct ggml_webgpu_binary_pipeline_key { - int type; - int op; - bool inplace; - bool overlap; + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } - bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { - return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_scale, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + scale_pipelines[key] = pipeline; + return scale_pipelines[key]; } -}; -struct ggml_webgpu_binary_pipeline_key_hash { - size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.type); - ggml_webgpu_hash_combine(seed, key.op); - ggml_webgpu_hash_combine(seed, key.inplace); - ggml_webgpu_hash_combine(seed, key.overlap); - return seed; + webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 }; + + auto it = pad_pipelines.find(key); + if (it != pad_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "pad"; + + if (key.circular) { + defines.push_back("CIRCULAR"); + variant += "_circular"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_pad, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + pad_pipelines[key] = pipeline; + return pad_pipelines[key]; } -}; -struct ggml_webgpu_binary_shader_lib_context { - ggml_webgpu_binary_pipeline_key key; - uint32_t max_wg_size; + webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_vec_pipeline_key key = { + .src0_type = context.src0->type, + .src1_type = context.src1->type, + // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float + .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0, + }; + + auto it = mul_mat_vec_pipelines.find(key); + if (it != mul_mat_vec_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "mul_mat_vec"; + + // src1 type (vector) + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat_vec shader"); + } + + // src0 type (matrix row) + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("MUL_ACC_FLOAT"); + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("MUL_ACC_FLOAT"); + break; + default: + { + // Quantized types: use helpers but accumulate in f16 + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + std::string src0_name = src0_traits->type_name; + std::string type_upper = src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("MUL_ACC_" + type_upper); + + // For fast path we always dequantize from f16 inside the shader + defines.push_back("SRC0_INNER_TYPE=f16"); + break; + } + } + + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + + uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; + uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K; + uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + defines.push_back(std::string("TILE_K=") + std::to_string(tile_k)); + defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); + + auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + decisions->tile_k = tile_k; + decisions->outputs_per_wg = outputs_per_wg; + decisions->vec_size = key.vectorized ? 4 : 1; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_vec_pipelines[key] = pipeline; + return mul_mat_vec_pipelines[key]; + } + + webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_pipeline_key key = { + .src0_type = context.src0->type, + .src1_type = context.src1->type, + .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0, + .use_subgroup_matrix = context.supports_subgroup_matrix + }; + + auto it = mul_mat_fast_pipelines.find(key); + if (it != mul_mat_fast_pipelines.end()) { + return it->second; + } + + const char * shader_src = key.use_subgroup_matrix ? wgsl_mul_mat_subgroup_matrix : wgsl_mul_mat_reg_tile; + std::vector defines; + std::string variant = key.use_subgroup_matrix ? "mul_mat_subgroup_matrix" : "mul_mat_reg_tile"; + + // src1 type + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat fast shader"); + } + + // src0 type + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + const char * src0_name = src0_traits->type_name; + + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("FLOAT"); + defines.push_back("MUL_ACC_FLOAT"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("FLOAT"); + defines.push_back("MUL_ACC_FLOAT"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + variant += "_f16"; + break; + default: + { + std::string type_upper = src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("MUL_ACC_" + type_upper); + defines.push_back("INIT_SRC0_SHMEM_" + type_upper); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + + // Use f16 inside the shader for quantized types + defines.push_back("SRC0_INNER_TYPE=f16"); + + variant += std::string("_") + src0_name; + break; + } + } + + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + + // Tiles + defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); + defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); + defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); + + // Subgroup matrix specifics + if (key.use_subgroup_matrix) { + defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); + defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u"); + defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u"); + defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u"); + defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u"); + defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u"); + defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u"); + defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u"); + } + + // variant suffix for src1 type + variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16"); + if (key.vectorized) { + variant += "_vectorized"; + } + + if (!key.use_subgroup_matrix) { + defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); + defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); + } + + auto processed = preprocessor.preprocess(shader_src, defines); + + auto decisions = std::make_shared(); + decisions->tile_k = WEBGPU_MUL_MAT_TILE_K; + decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; + decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; + decisions->use_subgroup_matrix = key.use_subgroup_matrix; + if (key.use_subgroup_matrix) { + decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M; + decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N; + decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M; + decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N; + decisions->wg_size = context.max_subgroup_size; + } else { + decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M; + decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N; + decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N; + decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE; + } + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_fast_pipelines[key] = pipeline; + return mul_mat_fast_pipelines[key]; + } + + webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type, + .src1_type = context.src1->type }; + + auto it = mul_mat_legacy_pipelines.find(key); + if (it != mul_mat_legacy_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "mul_mat"; + + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_TYPE=f32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_TYPE=f16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat legacy shader"); + } + + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + const char * src0_name = src0_traits->type_name; + + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_TYPE=f32"); + defines.push_back("FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_TYPE=f16"); + defines.push_back("FLOAT"); + variant += "_f16"; + break; + default: + { + // quantized types + std::string type_upper = src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back(std::string("SRC0_TYPE=") + src0_name); + defines.push_back("BYTE_HELPERS"); + defines.push_back(type_upper + "_T"); + defines.push_back(type_upper); + defines.push_back(type_upper + "_SCALE_MIN"); + defines.push_back(type_upper + "_TABLES"); + defines.push_back(type_upper + "_GRID"); + + variant += std::string("_") + src0_name; + break; + } + } + + auto processed = preprocessor.preprocess(wgsl_mul_mat, defines); + + auto decisions = std::make_shared(); + decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_legacy_pipelines[key] = pipeline; + return mul_mat_legacy_pipelines[key]; + } + + webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool is_unary = context.dst->op == GGML_OP_UNARY; + const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op; + ggml_webgpu_unary_pipeline_key key = { + .type = context.dst->type, + .op = op, + .is_unary = is_unary, + .inplace = context.inplace, + }; + + auto it = unary_pipelines.find(key); + if (it != unary_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = + key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) : ggml_op_name((ggml_op) key.op); + defines.push_back(variant); + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for unary shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_unary, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + unary_pipelines[key] = pipeline; + return unary_pipelines[key]; + } + + webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_binary_pipeline_key key = { + .type = context.dst->type, + .op = context.dst->op, + .inplace = context.inplace, + .overlap = context.overlap, + }; + + auto it = binary_pipelines.find(key); + if (it != binary_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string op_name = ggml_op_name((ggml_op) key.op); + std::string variant = op_name; + + defines.push_back(std::string("OP_") + op_name); + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for binary shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } else if (key.overlap) { + defines.push_back("OVERLAP"); + variant += "_overlap"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_binary, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + binary_pipelines[key] = pipeline; + return binary_pipelines[key]; + } + + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool has_mask = context.src3 != nullptr; + const bool has_sinks = context.src4 != nullptr; + + bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) && + (context.src1->ne[1] % context.sg_mat_n == 0); + + ggml_webgpu_flash_attn_pipeline_key key = { + .kv_type = context.src1->type, + .head_dim_qk = (uint32_t) context.src0->ne[0], + .head_dim_v = (uint32_t) context.src2->ne[0], + .kv_direct = kv_direct, + .has_mask = has_mask, + .has_sinks = has_sinks, + .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f, + }; + + auto it = flash_attn_pipelines.find(key); + if (it != flash_attn_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "flash_attn"; + + switch (key.kv_type) { + case GGML_TYPE_F32: + defines.push_back("KV_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("KV_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("KV_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("KV_Q8_0"); + break; + default: + GGML_ABORT("Unsupported KV type for flash attention shader"); + } + variant += std::string("_") + ggml_type_name(key.kv_type); + + if (key.has_mask) { + defines.push_back("MASK"); + variant += "_mask"; + } + if (key.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (key.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + if (key.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + + uint32_t q_tile = context.sg_mat_m; + uint32_t kv_tile = + std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k, + context.wg_mem_limit_bytes, context.max_subgroup_size }), + context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + if (key.kv_direct) { + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= context.sg_mat_n; + } + } + + defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + + uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + auto processed = preprocessor.preprocess(wgsl_flash_attn, defines); + auto decisions = std::make_shared(); + decisions->q_tile = q_tile; + decisions->kv_tile = kv_tile; + decisions->wg_size = wg_size; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + flash_attn_pipelines[key] = pipeline; + return flash_attn_pipelines[key]; + } + + private: + static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, + std::string shader_code, + std::string label) { + wgpu::ShaderSourceWGSL shader_source; + shader_source.code = shader_code.c_str(); + + wgpu::ShaderModuleDescriptor shader_desc; + shader_desc.nextInChain = &shader_source; + + wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc); + + wgpu::ComputePipelineDescriptor pipeline_desc; + pipeline_desc.label = label.c_str(); + pipeline_desc.compute.module = shader_module; + pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code + pipeline_desc.layout = nullptr; // nullptr means auto layout + return { device.CreateComputePipeline(&pipeline_desc), label }; + } + + static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { + const size_t limit_bytes = context.wg_mem_limit_bytes; + const size_t q_tile = context.sg_mat_m; + const size_t base_q_bytes = + (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!context.key.kv_direct) { + bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); + } + if (context.key.has_mask) { + bytes_per_kv += q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; + const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; + } }; -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_binary_shader_lib_context & context) { - std::vector defines; - std::string op_name = ggml_op_name((ggml_op) context.key.op); - std::string variant = op_name; - - defines.push_back(std::string("OP_") + op_name); - - switch (context.key.type) { - case GGML_TYPE_F32: - defines.push_back("TYPE_F32"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("TYPE_F16"); - variant += "_f16"; - break; - default: - GGML_ABORT("Unsupported type for binary shader"); - } - - if (context.key.inplace) { - defines.push_back("INPLACE"); - variant += "_inplace"; - } else if (context.key.overlap) { - defines.push_back("OVERLAP"); - variant += "_overlap"; - } - - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; - return result; -} #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 32e12026..17bb2f47 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,7 +8,6 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "ggml-webgpu-shader-lib.hpp" -#include "ggml-wgsl-shaders.hpp" #include "pre_wgsl.hpp" #ifdef __EMSCRIPTEN__ @@ -23,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -69,50 +69,29 @@ /* Constants */ -// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed. -#define WEBGPU_MAX_WG_SIZE 288 - -#define WEBGPU_MUL_MAT_WG_SIZE 256 #define WEBGPU_NUM_PARAM_BUFS 16u #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 -// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool +// Maximum number of in-flight submissions per-thread, to avoid exhausting the +// parameter buffer pool #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 -// For operations which process a row in parallel, this seems like a reasonable default +// For operations which process a row in parallel, this seems like a reasonable +// default #define WEBGPU_ROW_SPLIT_WG_SIZE 64 -// Matrix multiplication parameters - -// Register tiling parameters -#define WEBGPU_MUL_MAT_TILE_M 8 -#define WEBGPU_MUL_MAT_TILE_N 8 -#define WEBGPU_MUL_MAT_WG_SIZE_M 8 -#define WEBGPU_MUL_MAT_WG_SIZE_N 8 -#define WEBGPU_MUL_MAT_TILE_K 32 - -// Subgroup matrix parameters -// The number of subgroups in the M dimension -#define WEBGPU_MUL_MAT_SUBGROUP_M 2 -// The number of subgroups in the N dimension -#define WEBGPU_MUL_MAT_SUBGROUP_N 2 -// The number of subgroup matrices each subgroup accumulates over -#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 -#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 - -// Matrix-vector multiplication parameters -#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 -// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size -#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_TILE_K 256 +// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to +// implementations so this can be removed, necessary only for get_rows right now +#define WEBGPU_MAX_WG_SIZE 288 /* End Constants */ -// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. +// This is a "fake" base pointer, since WebGPU buffers do not have pointers to +// their locations. static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT // Always returns the base offset of a tensor, regardless of views. @@ -263,12 +242,6 @@ struct webgpu_gpu_profile_buf_pool { }; #endif -struct webgpu_pipeline { - wgpu::ComputePipeline pipeline; - std::string name; - std::shared_ptr context = nullptr; -}; - struct webgpu_command { wgpu::CommandBuffer commands; std::vector params_bufs; @@ -353,41 +326,18 @@ struct webgpu_context_struct { // Points to global instances owned by ggml_backend_webgpu_reg_context webgpu_global_context global_ctx; - pre_wgsl::Preprocessor p; + std::unique_ptr shader_lib; webgpu_buf_pool param_buf_pool; webgpu_buf_pool set_rows_error_buf_pool; - std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized - std::map>> - mul_mat_vec_pipelines; // src0_type, src1_type, vectorized - - std::unordered_map - flash_attn_pipelines; - - std::unordered_map argmax_pipelines; // key is vec4 - std::unordered_map argsort_pipelines; // key is order (asc/desc) - std::unordered_map argsort_merge_pipelines; // key is order (asc/desc) - std::unordered_map cumsum_pipelines; // key is fixed, no variants yet - std::unordered_map sum_rows_pipelines; // key is fixed, no variants yet - - std::unordered_map - set_rows_pipelines; - std::map> get_rows_pipelines; // src_type, vectorized - - std::map> cpy_pipelines; // src_type, dst_type - - std::unordered_map - binary_pipelines; + std::map> cpy_pipelines; // src_type, dst_type std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace std::map>> glu_pipelines; // glu_op, type, split - std::map scale_pipelines; // inplace + std::map>> soft_max_pipelines; // mask_type, has_sink, inplace - std::unordered_map - unary_pipelines; - std::unordered_map pad_pipelines; size_t memset_bytes_per_thread; }; @@ -429,25 +379,6 @@ struct ggml_backend_webgpu_buffer_context { /* WebGPU object initializations */ -// Process a WGSL shader string, replacing tokens of the form {{KEY}} with -// the corresponding values provided in `repls`. -static std::string ggml_webgpu_process_shader_repls(const char * src, - const std::map & repls) { - if (!src) { - return std::string(); - } - std::string s = src; - for (const auto & kv : repls) { - std::string token = "{{" + kv.first + "}}"; - size_t pos = 0; - while ((pos = s.find(token, pos)) != std::string::npos) { - s.replace(pos, token.length(), kv.second); - pos += kv.second.length(); - } - } - return s; -} - static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, const char * shader_code, const char * label, @@ -495,8 +426,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, std::vector & futures, bool block = true) { - // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads, - // inflight_max may be 0, meaning that we must wait on all futures. + // If we have too many in-flight submissions, wait on the oldest one first. If + // there are many threads, inflight_max may be 0, meaning that we must wait on + // all futures. uint64_t timeout_ms = block ? UINT64_MAX : 0; uint32_t inflight_threads = ctx->inflight_threads; uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); @@ -681,7 +613,8 @@ static webgpu_command ggml_backend_webgpu_build_multi( encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); } - // If there are SET_ROWS operations in this submission, copy their error buffers to the host. + // If there are SET_ROWS operations in this submission, copy their error + // buffers to the host. if (set_rows_error_bufs) { encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, set_rows_error_bufs->host_buf.GetSize()); @@ -900,24 +833,11 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g } static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - const bool circular = ggml_get_op_params_i32(dst, 8) != 0; - - ggml_webgpu_pad_pipeline_key pipeline_key = { .circular = circular }; - ggml_webgpu_pad_shader_lib_context shader_lib_ctx = { - .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; - webgpu_pipeline pipeline; - auto it = ctx->pad_pipelines.find(pipeline_key); - if (it != ctx->pad_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->pad_pipelines.emplace(pipeline_key, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -971,36 +891,25 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { - // For set rows specifically, we need to check if src and idx are empty tensors. + // For set rows specifically, we need to check if src and idx are empty + // tensors. if (ggml_is_empty(src) || ggml_is_empty(idx)) { return std::nullopt; } - ggml_webgpu_set_rows_pipeline_key key = { .dst_type = dst->type, - .vec4 = src->ne[0] % 4 == 0, - .i64_idx = idx->type == GGML_TYPE_I64 }; - - ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = { - .key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = idx, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; - webgpu_pipeline pipeline; - auto it = ctx->set_rows_pipelines.find(key); - if (it != ctx->set_rows_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->set_rows_pipelines.emplace(key, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx); - auto * decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); std::optional error_bufs = std::nullopt; - if (key.i64_idx) { + if (decisions->i64_idx) { error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs(); if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { error_bufs->host_buf.Unmap(); @@ -1038,13 +947,13 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - if (key.i64_idx) { + if (decisions->i64_idx) { entries.push_back( { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() }); } uint32_t threads; - if (key.vec4) { + if (decisions->vec4) { threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); } else { threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; @@ -1054,26 +963,47 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, error_bufs); } +// Workgroup size is a common constant +static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_size) { + std::vector constants(1); + constants[0].key = "wg_size"; + constants[0].value = wg_size; + return constants; +} + static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { - std::vector params = { - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - // Convert byte-strides to element-strides - (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)), - (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), - (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)), - (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), - (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Shape of dst - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], - // Shape of idx - (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2]) + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = nullptr, + .dst = dst, + .max_wg_size = WEBGPU_MAX_WG_SIZE, }; + webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), + (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), + (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + (uint32_t) (idx->ne[1]), + (uint32_t) (idx->ne[2]) }; + std::vector entries = { { .binding = 0, .buffer = ggml_webgpu_tensor_buf(src), @@ -1089,10 +1019,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size); - uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0; - webgpu_pipeline pipeline = ctx->get_rows_pipelines[src->type][vectorized]; return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1100,25 +1028,74 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { + // Determine if this is a mat-vec operation + bool is_vec = (dst->ne[1] == 1); + + // Determine if we should use fast path + bool use_fast = false; + switch (src1->type) { + case GGML_TYPE_F16: + use_fast = (src0->type == GGML_TYPE_F16); + break; + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + use_fast = true; + break; + default: + break; + } + break; + default: + break; + } + + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix, + .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, + .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, + .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, + .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, + }; + + // Get or create pipeline + webgpu_pipeline pipeline; + + if (use_fast && is_vec) { + pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx); + } else if (use_fast) { + pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx); + } else { + pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx); + } + + // Build params std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) dst->ne[0], // number of rows in result (M, transposed) - (uint32_t) dst->ne[1], // number of columns in result (N) - (uint32_t) src0->ne[0], // number of columns in src0/src1 (K) - (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1 - (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1 - (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2 - (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2 - (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3 - (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3 - (uint32_t) src0->ne[2], // batch size in dimension 2 - (uint32_t) src0->ne[3], // batch size in dimension 3 - (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2 - (uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3 + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) src0->ne[0], + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) src0->ne[2], + (uint32_t) src0->ne[3], + (uint32_t) (src1->ne[2] / src0->ne[2]), + (uint32_t) (src1->ne[3] / src0->ne[3]) }; + // Build bind group entries std::vector entries = { { .binding = 0, .buffer = ggml_webgpu_tensor_buf(src0), @@ -1134,68 +1111,44 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, }; - webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0]; - - uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE); + // Calculate workgroup dimensions + uint32_t wg_x = 1; uint32_t wg_y = 1; - bool use_fast = false; - switch (src1->type) { - case GGML_TYPE_F16: - use_fast = (src0->type == GGML_TYPE_F16); - break; - case GGML_TYPE_F32: - switch (src0->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_Q4_0: - use_fast = true; - break; - default: - break; - } - break; - default: - break; - } - - if (use_fast) { - int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; - if (dst->ne[1] == 1) { - // We don't support vectorized mul_mat_vec for quantized types - vectorized = vectorized && (src0->type < 2); - pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized]; - uint32_t batches = dst->ne[2] * dst->ne[3]; - uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG); - uint32_t total_wg = output_groups * batches; - wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); + if (use_fast && is_vec) { + auto decisions = static_cast(pipeline.context.get()); + + uint32_t batches = dst->ne[2] * dst->ne[3]; + uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); + uint32_t total_wg = output_groups * batches; + wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); + } else if (use_fast) { + auto decisions = static_cast(pipeline.context.get()); + + // Fast-path tiled/subgroup calculations + uint32_t wg_m, wg_n; + if (decisions->use_subgroup_matrix) { + uint32_t wg_m_sg_tile = + decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m; + wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); + uint32_t wg_n_sg_tile = + decisions->subgroup_n * decisions->subgroup_matrix_n * ctx->global_ctx->capabilities.sg_mat_n; + wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); } else { - pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; - uint32_t wg_m; - uint32_t wg_n; -#ifndef __EMSCRIPTEN__ - if (ctx->global_ctx->capabilities.supports_subgroup_matrix) { - // The total number of subgroups/workgroups needed per matrix. - uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * - ctx->global_ctx->capabilities.sg_mat_m; - wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); - uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * - ctx->global_ctx->capabilities.sg_mat_n; - wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); - } else { -#endif - uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; - uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; - wg_m = CEIL_DIV(dst->ne[0], tile_m_s); - wg_n = CEIL_DIV(dst->ne[1], tile_n_s); -#ifndef __EMSCRIPTEN__ - } -#endif - - wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + uint32_t tile_m_s = decisions->tile_m * decisions->wg_size_m; + uint32_t tile_n_s = decisions->tile_n * decisions->wg_size_n; + wg_m = CEIL_DIV(dst->ne[0], tile_m_s); + wg_n = CEIL_DIV(dst->ne[1], tile_n_s); } + wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + } else { // legacy + auto decisions = static_cast(pipeline.context.get()); + uint32_t wg_size = decisions->wg_size; + wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); + wg_y = 1; } + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); } @@ -1283,40 +1236,22 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && - (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - - ggml_webgpu_flash_attn_pipeline_key key = { - .kv_type = K->type, - .head_dim_qk = (uint32_t) Q->ne[0], - .head_dim_v = (uint32_t) V->ne[0], - .kv_direct = kv_direct, - .has_mask = static_cast(has_mask), - .has_sinks = static_cast(has_sinks), - .uses_logit_softcap = logit_softcap != 0.0f, + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = Q, + .src1 = K, + .src2 = V, + .src3 = mask, + .src4 = sinks, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, + .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, + .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, + .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, }; - webgpu_pipeline pipeline; - auto it = ctx->flash_attn_pipelines.find(key); - if (it != ctx->flash_attn_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { - .key = key, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size - }; - - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->flash_attn_pipelines.emplace(key, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -1329,27 +1264,16 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); - int op = is_unary ? (int) ggml_get_unary_op(dst) : dst->op; - ggml_webgpu_unary_pipeline_key pipeline_key = { - .type = dst->type, .op = op, .is_unary = is_unary, .inplace = inplace - }; - ggml_webgpu_unary_shader_lib_context shader_lib_ctx = { - .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = nullptr, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = inplace, }; - webgpu_pipeline pipeline; - auto it = ctx->unary_pipelines.find(pipeline_key); - if (it != ctx->unary_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->unary_pipelines.emplace(pipeline_key, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -1421,30 +1345,18 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * dst) { binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); - ggml_webgpu_binary_pipeline_key pipeline_key = { - .type = dst->type, - .op = dst->op, - .inplace = flags.inplace, - .overlap = flags.overlap, - }; - ggml_webgpu_binary_shader_lib_context shader_lib_ctx = { - .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = flags.inplace, + .overlap = flags.overlap, }; - webgpu_pipeline pipeline; - auto it = ctx->binary_pipelines.find(pipeline_key); - if (it != ctx->binary_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->binary_pipelines.emplace(pipeline_key, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); - auto * decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -1669,8 +1581,20 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, } static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - int inplace = ggml_webgpu_tensor_equal(src, dst); + bool inplace = ggml_webgpu_tensor_equal(src, dst); + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = nullptr, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = inplace, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + // params unchanged std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1688,12 +1612,14 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, *(uint32_t *) &dst->op_params[1] // bias }; + // bindgroups unchanged std::vector entries = { { .binding = 0, .buffer = ggml_webgpu_tensor_buf(src), .offset = ggml_webgpu_tensor_align_offset(ctx, src), .size = ggml_webgpu_tensor_binding_size(ctx, src) } }; + if (!inplace) { entries.push_back({ .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), @@ -1701,9 +1627,8 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->scale_pipelines[inplace], params, - entries, wg_x); + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, @@ -1796,63 +1721,30 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { - .vec4 = src->ne[0] % 4 == 0, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; - webgpu_pipeline pipeline; - auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4); - if (it != ctx->argmax_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax"); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline); - } - uint32_t wg_x = ggml_nelements(dst); + webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); + uint32_t wg_x = ggml_nelements(dst); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - bool is_top_k = dst->op == GGML_OP_TOP_K; - // ascending order is 0, descending order is 1 - const int32_t order = is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(dst, 0); + bool is_top_k = dst->op == GGML_OP_TOP_K; - ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = nullptr, + .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .order = order }; - webgpu_pipeline argsort_pipeline; - auto it = ctx->argsort_pipelines.find(order); - if (it != ctx->argsort_pipelines.end()) { - argsort_pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_argsort_shader(ctx->p, wgsl_argsort, shader_lib_ctx); - argsort_pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - argsort_pipeline.context = processed.decisions; - ctx->argsort_pipelines.emplace(order, argsort_pipeline); - } - auto * argsort_decisions = static_cast(argsort_pipeline.context.get()); - - webgpu_pipeline argsort_merge_pipeline; - it = ctx->argsort_merge_pipelines.find(order); - if (it != ctx->argsort_merge_pipelines.end()) { - argsort_merge_pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_argsort_merge_shader(ctx->p, wgsl_argsort_merge, shader_lib_ctx); - argsort_merge_pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - argsort_merge_pipeline.context = processed.decisions; - ctx->argsort_merge_pipelines.emplace(order, argsort_merge_pipeline); - } + webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx); + auto * argsort_decisions = static_cast(argsort_pipeline.context.get()); + + webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx); const uint32_t src_ne0 = (uint32_t) src->ne[0]; const uint32_t nrows = (uint32_t) ggml_nrows(src); @@ -2011,22 +1903,15 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { - .vec4 = false, + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = nullptr, + .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; - webgpu_pipeline pipeline; - auto it = ctx->cumsum_pipelines.find(1); - if (it != ctx->cumsum_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum"); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->cumsum_pipelines.emplace(1, pipeline); - } - uint32_t wg_x = ggml_nrows(dst); + + webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); + uint32_t wg_x = ggml_nrows(dst); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -2052,22 +1937,12 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { - .vec4 = false, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; - webgpu_pipeline pipeline; - auto it = ctx->sum_rows_pipelines.find(1); - if (it != ctx->sum_rows_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows"); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->sum_rows_pipelines.emplace(1, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); + uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -2233,7 +2108,9 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe size_t offset, size_t size) { if (size == 0) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do."); + WEBGPU_LOG_DEBUG( + "ggml_backend_webgpu_buffer_memset_tensor: size is zero, " + "nothing to do."); return; } @@ -2310,7 +2187,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, size_t final_size = size; if (size % 4 != 0) { - // If size is not a multiple of 4, we need to round it up to the next multiple of 4 + // If size is not a multiple of 4, we need to round it up to the next + // multiple of 4 final_size = size + (4 - (size % 4)); } @@ -2364,7 +2242,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor, /* .cpy_tensor = */ NULL, // TODO: optional, implement this /* .clear = */ ggml_backend_webgpu_buffer_clear, - /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor + /* .reset = */ NULL, // TODO: optional, think it coordinates with + // .init_tensor }; /* End GGML Backend Buffer Interface */ @@ -2401,7 +2280,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_ return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; } -// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding. +// maxBufferSize might be larger, but you can't bind more than +// maxStorageBufferBindingSize to a single binding. static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { ggml_backend_webgpu_device_context * dev_ctx = static_cast(buft->device->context); @@ -2487,14 +2367,6 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) { return reinterpret_cast((void *) guid_str); } -// Workgroup size is a common constant -static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_size) { - std::vector constants(1); - constants[0].key = "wg_size"; - constants[0].value = wg_size; - return constants; -} - static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { // we use the maximum workgroup size for the memset pipeline size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; @@ -2509,207 +2381,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } -static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { - // Q4/Q5/Q8 classic quantizations - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); - - // K-quantizations - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); - - // IQ quantizations (2-, 3-, 4-bit variants) - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); - - // 1-bit and 4-bit IQ variants - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); - - std::string proc_mul_mat_f32_f32; - std::string proc_mul_mat_f32_f32_vec; - std::string proc_mul_mat_f16_f32; - std::string proc_mul_mat_f16_f32_vec; - std::string proc_mul_mat_f16_f16; - std::string proc_mul_mat_f16_f16_vec; - std::string proc_mul_mat_q4_0_f32; - std::string proc_mul_mat_q4_0_f32_vec; - - std::vector mul_mat_constants; -#ifndef __EMSCRIPTEN__ - if (webgpu_ctx->global_ctx->capabilities.supports_subgroup_matrix) { - std::map sg_matrix_repls; - sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = - std::to_string(webgpu_ctx->global_ctx->capabilities.max_subgroup_size); - sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K); - sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M); - sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N); - sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M); - sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N); - sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_m); - sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_n); - sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_k); - proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); - proc_mul_mat_f32_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); - proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); - proc_mul_mat_f16_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls); - proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); - proc_mul_mat_f16_f16_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls); - proc_mul_mat_q4_0_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls); - proc_mul_mat_q4_0_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls); - } else { -#endif - mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K }); - mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M }); - mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N }); - - std::map reg_repls; - reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M); - reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N); - - proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); - proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); - proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); - proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); - proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); - proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); - proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); - proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); -#ifndef __EMSCRIPTEN__ - } -#endif - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); - - std::vector mul_mat_vec_constants(3); - mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; - mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE; - mul_mat_vec_constants[1].key = "TILE_K"; - mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K; - mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG"; - mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; - - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); -} - -static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); -} - static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); @@ -2816,15 +2487,6 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); } -static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->scale_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32, "scale_f32", constants); - webgpu_ctx->scale_pipelines[1] = ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32_inplace, - "scale_f32_inplace", constants); -} - static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); @@ -3015,6 +2677,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context; webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; + webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); @@ -3023,13 +2686,10 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); - ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx); - ggml_webgpu_init_get_rows_pipeline(webgpu_ctx); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); ggml_webgpu_init_rope_pipeline(webgpu_ctx); ggml_webgpu_init_glu_pipeline(webgpu_ctx); - ggml_webgpu_init_scale_pipeline(webgpu_ctx); ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers @@ -3071,11 +2731,11 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = { /* .iface = */ { /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size, - /* .get_alloc_size = */ ggml_backend_webgpu_buffer_type_get_alloc_size, - /* .is_host = */ NULL, // defaults to false + /* .alloc_buffer = */ + ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */ + ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */ + ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */ + ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false }, /* .device = */ dev, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 389c97bb..9a5b18eb 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -1,5 +1,4 @@ -#decl(BYTE_HELPERS) - +#ifdef BYTE_HELPERS fn get_byte(value: u32, index: u32) -> u32 { return (value >> (index * 8)) & 0xFF; } @@ -7,76 +6,74 @@ fn get_byte(value: u32, index: u32) -> u32 { fn get_byte_i32(value: u32, index: u32) -> i32 { return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; } +#endif -#enddecl(BYTE_HELPERS) - -#decl(Q4_0_T) +#ifdef Q4_0_T struct q4_0 { d: f16, qs: array }; -#enddecl(Q4_0_T) +#endif -#decl(Q4_1_T) +#ifdef Q4_1_T struct q4_1 { d: f16, m: f16, qs: array }; -#enddecl(Q4_1_T) +#endif -#decl(Q5_0_T) +#ifdef Q5_0_T struct q5_0 { d: f16, qh: array, qs: array }; -#enddecl(Q5_0_T) +#endif -#decl(Q5_1_T) +#ifdef Q5_1_T struct q5_1 { d: f16, m: f16, qh: u32, qs: array }; -#enddecl(Q5_1_T) +#endif -#decl(Q8_0_T) +#ifdef Q8_0_T struct q8_0 { d: f16, qs: array }; -#enddecl(Q8_0_T) +#endif -#decl(Q8_1_T) +#ifdef Q8_1_T struct q8_1 { d: f16, m: f16, qs: array }; -#enddecl(Q8_1_T) +#endif -#decl(Q2_K_T) -struct q2_k { +#ifdef Q2_K_T +struct q2_K { scales: array, qs: array, d: f16, dmin: f16 }; -#enddecl(Q2_K_T) +#endif -#decl(Q3_K_T) -struct q3_k { +#ifdef Q3_K_T +struct q3_K { hmask: array, qs: array, scales: array, d: f16 }; -#enddecl(Q3_K_T) - -#decl(Q45_K_SCALE_MIN) +#endif +#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN) fn get_scale_min(is: u32, scales: array) -> vec2 { if (is < 4) { let sc_byte = get_byte(scales[is / 4], is % 4); @@ -91,69 +88,67 @@ fn get_scale_min(is: u32, scales: array) -> vec2 { return vec2(f32(sc), f32(m)); } } - -#enddecl(Q45_K_SCALE_MIN) - -#decl(Q4_K_T) -struct q4_k { +#endif +#ifdef Q4_K_T +struct q4_K { d: f16, dmin: f16, scales: array, qs: array }; -#enddecl(Q4_K_T) +#endif -#decl(Q5_K_T) -struct q5_k { +#ifdef Q5_K_T +struct q5_K { d: f16, dmin: f16, scales: array, qh: array, qs: array }; -#enddecl(Q5_K_T) +#endif -#decl(Q6_K_T) -struct q6_k { +#ifdef Q6_K_T +struct q6_K { ql: array, qh: array, scales: array, d: f16 }; -#enddecl(Q6_K_T) +#endif -#decl(IQ2_XXS_T) +#ifdef IQ2_XXS_T struct iq2_xxs { d: f16, qs: array }; -#enddecl(IQ2_XXS_T) +#endif -#decl(IQ2_XS_T) +#ifdef IQ2_XS_T struct iq2_xs { d: f16, qs: array, scales: array }; -#enddecl(IQ2_XS_T) +#endif -#decl(IQ2_S_T) +#ifdef IQ2_S_T struct iq2_s { d: f16, qs: array, qh: array, scales: array }; -#enddecl(IQ2_S_T) +#endif -#decl(IQ3_XSS_T) +#ifdef IQ3_XXS_T struct iq3_xxs { d: f16, qs: array }; -#enddecl(IQ3_XSS_T) +#endif -#decl(IQ3_S_T) +#ifdef IQ3_S_T struct iq3_s { d: f16, qs: array, @@ -161,41 +156,41 @@ struct iq3_s { signs: array, scales: array }; -#enddecl(IQ3_S_T) +#endif -#decl(IQ1_S_T) +#ifdef IQ1_S_T struct iq1_s { d: f16, qs: array, qh: array }; -#enddecl(IQ1_S_T) +#endif -#decl(IQ1_M_T) +#ifdef IQ1_M_T struct iq1_m { qs: array, qh: array, scales: array }; -#enddecl(IQ1_M_T) +#endif -#decl(IQ4_NL_T) +#ifdef IQ4_NL_T struct iq4_nl { d: f16, qs: array, }; -#enddecl(IQ4_NL_T) +#endif -#decl(IQ4_XS_T) +#ifdef IQ4_XS_T struct iq4_xs { d: f16, scales_h: f16, scales_l: u32, qs: array }; -#enddecl(IQ4_XS_T) +#endif -#decl(IQ23_TABLES) +#if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES) const kmask_iq2xs : array = array( 0x08040201u, // 1, 2, 4, 8 0x80402010u // 16, 32, 64, 128 @@ -211,9 +206,9 @@ const ksigns_iq2xs: array = array( 0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c, 0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc ); -#enddecl(IQ23_TABLES) +#endif -#decl(IQ2_XXS_GRID) +#ifdef IQ2_XXS_GRID const iq2xxs_grid = array( 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808, @@ -280,9 +275,9 @@ const iq2xxs_grid = array( 0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819, 0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19 ); -#enddecl(IQ2_XXS_GRID) +#endif -#decl(IQ2_XS_GRID) +#ifdef IQ2_XS_GRID const iq2xs_grid = array( 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, @@ -413,9 +408,9 @@ const iq2xs_grid = array( 0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19, 0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b ); -#enddecl(IQ2_XS_GRID) +#endif -#decl(IQ2_S_GRID) +#ifdef IQ2_S_GRID const iq2s_grid = array( 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, @@ -674,10 +669,9 @@ const iq2s_grid = array( 0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b, 0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b ); -#enddecl(IQ2_S_GRID) - -#decl(IQ3_XSS_GRID) +#endif +#ifdef IQ3_XXS_GRID const iq3xxs_grid = array( 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, @@ -712,10 +706,9 @@ const iq3xxs_grid = array( 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04 ); -#enddecl(IQ3_XSS_GRID) - -#decl(IQ3_S_GRID) +#endif +#ifdef IQ3_S_GRID const iq3s_grid = array( 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, @@ -782,9 +775,9 @@ const iq3s_grid = array( 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101 ); -#enddecl(IQ3_S_GRID) +#endif -#decl(IQ1_GRID) +#if defined(IQ1_S_GRID) || defined(IQ1_M_GRID) const IQ1_DELTA: f32 = 0.125; @@ -919,12 +912,12 @@ const iq1_grid = array( 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 ); -#enddecl(IQ1_GRID) +#endif -#decl(IQ4_GRID) +#if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID) const kvalues_iq4nl = array( -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 ); -#enddecl(IQ4_GRID) +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index d61df5bb..8b5cfe71 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -56,12 +56,46 @@ def expand_includes(shader, input_dir): return include_pattern.sub(replacer, shader) -def write_shader(shader_name, shader_code, output_dir, outfile): +def chunk_shader(shader_code, max_chunk_len=60000): + """Split shader_code into safe raw-string sized chunks.""" + return [shader_code[i : i + max_chunk_len] for i in range(0, len(shader_code), max_chunk_len)] + + +def raw_delim(shader_code): + """Pick a raw-string delimiter that does not appear in the shader.""" + delim = "wgsl" + while f"){delim}\"" in shader_code: + delim += "_x" + return delim + + +def write_shader(shader_name, shader_code, output_dir, outfile, input_dir): + shader_code = expand_includes(shader_code, input_dir) + if output_dir: wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl") with open(wgsl_filename, "w", encoding="utf-8") as f_out: f_out.write(shader_code) - outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n') + + delim = raw_delim(shader_code) + chunks = chunk_shader(shader_code) + + if len(chunks) == 1: + outfile.write(f'const char* wgsl_{shader_name} = R"{delim}({shader_code}){delim}";\n\n') + else: + for idx, chunk in enumerate(chunks): + outfile.write(f'static const char wgsl_{shader_name}_part{idx}[] = R"{delim}({chunk}){delim}";\n\n') + outfile.write(f'static const std::string& wgsl_{shader_name}_str() {{\n') + outfile.write(' static const std::string s = []{\n') + outfile.write(' std::string tmp;\n') + outfile.write(f' tmp.reserve({len(shader_code)});\n') + for idx in range(len(chunks)): + outfile.write(f' tmp.append(wgsl_{shader_name}_part{idx});\n') + outfile.write(' return tmp;\n') + outfile.write(' }();\n') + outfile.write(' return s;\n') + outfile.write('}\n') + outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n') def generate_variants(fname, input_dir, output_dir, outfile): @@ -74,7 +108,7 @@ def generate_variants(fname, input_dir, output_dir, outfile): try: variants = ast.literal_eval(extract_block(text, "VARIANTS")) except ValueError: - write_shader(shader_base_name, text, output_dir, outfile) + write_shader(shader_base_name, text, output_dir, outfile, input_dir) else: try: decls_map = parse_decls(extract_block(text, "DECLS")) @@ -123,7 +157,7 @@ def generate_variants(fname, input_dir, output_dir, outfile): output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"] else: output_name = shader_base_name - write_shader(output_name, final_shader, output_dir, outfile) + write_shader(output_name, final_shader, output_dir, outfile, input_dir) def main(): @@ -137,7 +171,8 @@ def main(): os.makedirs(args.output_dir, exist_ok=True) with open(args.output_file, "w", encoding="utf-8") as out: - out.write("// Auto-generated shader embedding\n\n") + out.write("// Auto-generated shader embedding\n") + out.write("#include \n\n") for fname in sorted(os.listdir(args.input_dir)): if fname.endswith(".wgsl"): generate_variants(fname, args.input_dir, args.output_dir, out) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl deleted file mode 100644 index f80ce1fc..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +++ /dev/null @@ -1,874 +0,0 @@ -#define(VARIANTS) - -[ - { - "SHADER_SUFFIX": "f32_vec", - "REPLS": { - "TYPE" : "vec4", - "DST_TYPE": "vec4", - "BLOCK_SIZE": 4 - }, - "DECLS": ["F32_VEC"] - }, - { - "REPLS": { - "TYPE" : "f32", - "DST_TYPE": "f32", - "BLOCK_SIZE": 1 - }, - "DECLS": ["F32"] - }, - { - "REPLS": { - "TYPE" : "f16", - "DST_TYPE": "f32", - "BLOCK_SIZE": 1 - }, - "DECLS": ["F16"] - }, - { - "REPLS": { - "TYPE" : "i32", - "DST_TYPE": "i32", - "BLOCK_SIZE": 1 - }, - "DECLS": ["I32"] - }, - { - "REPLS": { - "TYPE" : "q4_0", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"] - }, - { - "REPLS": { - "TYPE" : "q4_1", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"] - }, - { - "REPLS": { - "TYPE" : "q5_0", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"] - }, - { - "REPLS": { - "TYPE" : "q5_1", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"] - }, - { - "REPLS": { - "TYPE" : "q8_0", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"] - }, - { - "REPLS": { - "TYPE" : "q2_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"] - }, - { - "REPLS": { - "TYPE" : "q3_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"] - }, - { - "REPLS": { - "TYPE" : "q4_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"] - }, - { - "REPLS": { - "TYPE" : "q5_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"] - }, - { - "REPLS": { - "TYPE" : "q6_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"] - }, - { - "REPLS": { - "TYPE" : "iq2_xxs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"] - }, - { - "REPLS": { - "TYPE" : "iq2_xs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"] - }, - { - "REPLS": { - "TYPE": "iq2_s", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"] - }, - { - "REPLS": { - "TYPE": "iq3_xxs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"] - }, - { - "REPLS": { - "TYPE": "iq3_s", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"] - }, - { - "REPLS": { - "TYPE": "iq1_s", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"] - }, - { - "REPLS": { - "TYPE": "iq1_m", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"] - }, - { - "REPLS": { - "TYPE": "iq4_nl", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"] - }, - { - "REPLS": { - "TYPE": "iq4_xs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(F32_VEC) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset]; -} -#enddecl(F32_VEC) - -#decl(F32) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - dst[dst_base + offset] = src[src_base + offset]; -} -#enddecl(F32) - -#decl(F16) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - dst[dst_base + offset] = f32(src[src_base + offset]); -} -#enddecl(F16) - -#decl(I32) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - dst[dst_base + offset] = src[src_base + offset]; -} -#enddecl(I32) - -#decl(Q4_0) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q4_0 = src[src_base + offset]; - let d = f32(block_q4_0.d); - for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; - let dst_offset = dst_base + offset * 32 + j * 4 + k; - dst[dst_offset] = q_lo; - dst[dst_offset + 16] = q_hi; - } - } -} -#enddecl(Q4_0) - -#decl(Q4_1) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q4_1 = src[src_base + offset]; - let d = f32(block_q4_1.d); - let m = f32(block_q4_1.m); - for (var j: u32 = 0; j < 4; j++) { - let q_packed = block_q4_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32((q_byte >> 4) & 0xF) * d + m; - let q_lo = f32(q_byte & 0xF) * d + m; - let dst_offset = dst_base + offset * 32 + j * 4 + k; - dst[dst_offset] = q_lo; - dst[dst_offset + 16] = q_hi; - } - } -} -#enddecl(Q4_1) - -#decl(Q5_0) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q5_0 = src[src_base + offset]; - let d = f32(block_q5_0.d); - let qh_packed = bitcast(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); - for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; - let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; - let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; - let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; - let dst_offset = dst_base + offset * 32 + j * 4 + k; - dst[dst_offset] = q_lo; - dst[dst_offset + 16] = q_hi; - } - } -} - -#enddecl(Q5_0) - -#decl(Q5_1) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q5_1 = src[src_base + offset]; - let d = f32(block_q5_1.d); - let m = f32(block_q5_1.m); - for (var j: u32 = 0; j < 4; j++) { - let q_packed = block_q5_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10; - let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m; - let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10; - let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m; - let dst_offset = dst_base + offset * 32 + j * 4 + k; - dst[dst_offset] = q_lo; - dst[dst_offset + 16] = q_hi; - } - } -} -#enddecl(Q5_1) - -#decl(Q8_0) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q8_0 = src[src_base + offset]; - let d = f32(block_q8_0.d); - for (var j: u32 = 0; j < 8; j++) { - let q_packed = bitcast(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d; - let dst_offset = dst_base + offset * 32 + j * 4 + k; - dst[dst_offset] = q_val; - } - } -} -#enddecl(Q8_0) - -#decl(Q2_K) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var dst_i = dst_base + offset * 256; - var is: u32 = 0; - // 2 halves of the block (128 elements each) - for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { - // 4 groups (each group has 2 blocks of 16 elements) - for (var shift: u32 = 0; shift < 8; shift += 2) { - // 2 blocks - for (var k: u32 = 0; k < 32; k += 16) { - let sc = get_byte(block.scales[is / 4], is % 4); - is++; - let dl = d * f32(sc & 0xF); - let ml = m * f32(sc >> 4); - for (var l: u32 = 0u; l < 16; l++) { - let q_idx = q_b_idx + k + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qs_val = (q_byte >> shift) & 3; - dst[dst_i] = (f32(qs_val) * dl - ml); - dst_i++; - } - } - } - } -} -#enddecl(Q2_K) - -#decl(Q3_K) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); - - // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, - // and 2-bits from the last 4 bytes - let kmask1: u32 = 0x03030303; - let kmask2: u32 = 0x0f0f0f0f; - var scale_vals: array; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); - } - var tmp: u32 = scale_vals[2]; - scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); - scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - - // convert arrays of f16 -> u32 - var hmask_vals: array; - for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = bitcast(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); - } - var qs_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[2 * i], block.qs[2 * i + 1])); - } - - var dst_i = dst_base + offset * 256; - var is: u32 = 0; - var m: u32 = 1; - // 2 halves of the block (128 elements each) - for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { - // 4 groups (each group has 2 blocks of 16 elements) - for (var shift: u32 = 0; shift < 8; shift += 2) { - // 2 blocks - for (var k: u32 = 0; k < 32; k += 16) { - let sc = get_byte(scale_vals[is / 4], is % 4); - is++; - let dl = d * (f32(sc) - 32.0); - for (var l: u32 = 0u; l < 16u; l++) { - let q_idx = q_b_idx + k + l; - let hm_idx = k + l; - let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); - let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); - let hm = select(4.0, 0.0, (hmask_byte & m) != 0); - let qs_val = (q_byte >> shift) & 3; - dst[dst_i] = (f32(qs_val) - hm) * dl; - dst_i++; - } - } - m <<= 1; - } - } -} -#enddecl(Q3_K) - -#decl(Q4_K) -// 8 blocks of 32 elements each -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var dst_i = dst_base + offset * 256; - var is: u32 = 0; - // 2 blocks each iteration - for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { - for (var shift: u32 = 0; shift < 8; shift += 4) { - let scale_min = get_scale_min(is, block.scales); - is++; - let dl = d * scale_min.x; - let ml = m * scale_min.y; - for (var l: u32 = 0; l < 32; l++) { - let q_idx = q_b_idx + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qs_val = (q_byte >> shift) & 0xF; - dst[dst_i] = (f32(qs_val) * dl - ml); - dst_i++; - } - } - } -} -#enddecl(Q4_K) - -#decl(Q5_K) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var dst_i = dst_base + offset * 256; - var is: u32 = 0; - var u: u32 = 1; - // 2 blocks each iteration - for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { - for (var shift: u32 = 0; shift < 8; shift += 4) { - let scale_min = get_scale_min(is, block.scales); - is++; - let dl = d * scale_min.x; - let ml = m * scale_min.y; - for (var l: u32 = 0; l < 32; l++) { - let q_idx = q_b_idx + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qh_byte = get_byte(block.qh[l / 4], l % 4); - let qs_val = (q_byte >> shift) & 0xF; - let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); - dst[dst_i] = (f32(qs_val) + qh_val) * dl - ml; - dst_i++; - } - u <<= 1; - } - } -} -#enddecl(Q5_K) - -#decl(Q6_K) -// 16 blocks of 16 elements each -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); - - // convert arrays of f16 -> u32 - var ql_vals: array; - for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = bitcast(vec2(block.ql[2 * i], block.ql[2 * i + 1])); - } - var qh_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = bitcast(vec2(block.qh[2 * i], block.qh[2 * i + 1])); - } - var scale_vals: array; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); - } - - var dst_i = dst_base + offset * 256; - var qh_b_idx: u32 = 0; - var sc_b_idx: u32 = 0; - for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) { - for (var l: u32 = 0; l < 32; l++) { - let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4); - let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4); - let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4); - - let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0; - let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0; - let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0; - let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0; - - let is = l/16; - let is1 = sc_b_idx + is; - let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4); - let is2 = sc_b_idx + is + 2; - let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4); - let is3 = sc_b_idx + is + 4; - let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4); - let is4 = sc_b_idx + is + 6; - let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4); - - dst[dst_i + l] = (q1 * f32(sc1)) * d; - dst[dst_i + l + 32] = (q2 * f32(sc2)) * d; - dst[dst_i + l + 64] = (q3 * f32(sc3)) * d; - dst[dst_i + l + 96] = (q4 * f32(sc4)) * d; - } - dst_i += 128; - qh_b_idx += 32; - sc_b_idx += 8; - } -} - -#enddecl(Q6_K) - -#decl(IQ2_XXS) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); - var dst_i = dst_base + offset * 256; - for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0 = bitcast(vec2(block.qs[ib], block.qs[ib + 1])); - let aux1 = bitcast(vec2(block.qs[ib + 2], block.qs[ib + 3])); - let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; - for (var l: u32 = 0; l < 4; l++) { - let ig = get_byte(aux0, l) * 8; - let is = (aux1 >> (7 * l)) & 127; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - dst[dst_i] = db * f32(g) * m; - dst_i++; - } - } - } -} -#enddecl(IQ2_XXS) - -#decl(IQ2_XS) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); - var dst_i = dst_base + offset * 256; - var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) - ); - for (var ib: u32 = 0; ib < 32; ib += 4) { - let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); - let db = array( - d * (0.5 + f32(s & 0xF)) * 0.25, - d * (0.5 + f32(s >> 4)) * 0.25 - ); - for (var l: u32 = 0; l < 4; l++) { - let qs_val = bitcast(vec2(block.qs[ib + l], 0.0)); - let ig = (qs_val & 511) * 8; - let is = qs_val >> 9; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let dl = db[l/2]; - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - dst[dst_i] = dl * f32(g) * m; - dst_i++; - } - } - } -} -#enddecl(IQ2_XS) - -#decl(IQ2_S) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); - var dst_i = dst_base + offset * 256; - var qs_vals : array; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); - } - var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) - ); - var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) - ); - for (var ib: u32 = 0; ib < 8; ib ++) { - let s = get_byte(scale_vals[ib / 4], ib % 4); - let db = array( - d * (0.5 + f32(s & 0xF)) * 0.25, - d * (0.5 + f32(s >> 4)) * 0.25 - ); - let qs_w = qs_vals[ib]; - for (var l: u32 = 0; l < 4; l++) { - let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300; - let ig = (get_byte(qs_w, l) | qh_b) * 8; - let signs = get_byte(qs_vals[ib + 8], l); - let dl = db[l/2]; - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - dst[dst_i] = dl * f32(g) * m; - dst_i++; - } - } - } -} - -#enddecl(IQ2_S) - -#decl(IQ3_XSS) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); - var dst_i = dst_base + offset * 256; - for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign = bitcast(vec2(block.qs[ib + 32], block.qs[ib + 33])); - let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; - for (var l: u32 = 0; l < 4; l++) { - let is = (sc_sign >> (7 * l)) & 127; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = bitcast(vec2(block.qs[ib * 2 + l], 0.0)); - let ig1 = get_byte(ig_val, 0); - let ig2 = get_byte(ig_val, 1); - for (var j: u32 = 0; j < 4; j++) { - let g1 = get_byte(iq3xxs_grid[ig1], j); - let g2 = get_byte(iq3xxs_grid[ig2], j); - let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); - let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); - dst[dst_i] = db * f32(g1) * m1; - dst[dst_i + 4] = db * f32(g2) * m2; - dst_i++; - } - dst_i += 4; - } - } -} -#enddecl(IQ3_XSS) - -#decl(IQ3_S) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); - var dst_i = dst_base + offset * 256; - var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) - ); - var sign_vals: array; - for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = bitcast(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); - } - var scale_vals = bitcast(vec2(block.scales[0], block.scales[1])); - for (var ib: u32 = 0; ib < 4; ib++) { - let s = get_byte(scale_vals, ib); - let db = array( - d * (1.0 + 2.0 * f32(s & 0xF)), - d * (1.0 + 2.0 * f32(s >> 4)) - ); - for (var k: u32 = 0; k < 2; k++) { - let dl = db[k]; - let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k); - let sign_w = sign_vals[ib * 2 + k]; - for (var l: u32 = 0; l < 4; l++) { - let signs = get_byte(sign_w, l); - let ig_val = bitcast(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); - let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); - let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); - for (var j: u32 = 0; j < 4; j++) { - let g1 = get_byte(iq3s_grid[ig1], j); - let g2 = get_byte(iq3s_grid[ig2], j); - let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); - let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); - dst[dst_i] = dl * f32(g1) * m1; - dst[dst_i + 4] = dl * f32(g2) * m2; - dst_i++; - } - dst_i += 4; - } - } - } -} -#enddecl(IQ3_S) - -#decl(IQ1_S) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); - var dst_i = dst_base + offset * 256; - for (var ib: u32 = 0; ib < 8; ib++) { - let qh = bitcast(vec2(block.qh[ib], 0.0)); - let dl = d * (2 * f32((qh >> 12) & 7) + 1); - let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = bitcast(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); - for (var l: u32 = 0; l < 4; l++) { - let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; - for (var j: u32 = 0; j < 8; j++) { - let gw = iq1_grid[(ig + j) / 16]; - let g = (gw >> (((ig + j) % 16) * 2)) & 3; - let gs = bitcast(g << 30) >> 30; - dst[dst_i] = dl * (f32(gs) + delta); - dst_i++; - } - } - } -} - -#enddecl(IQ1_S) - -#decl(IQ1_M) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - - let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000); - let d = f32(bitcast>(scale).x); - var dst_i = dst_base + offset * 256; - for (var ib: u32 = 0; ib < 8; ib++) { - let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF; - let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7; - let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7; - var dl = array( - d * f32(2 * s1 + 1), - d * f32(2 * s2 + 1) - ); - - let qh = block.qh[ib / 2] >> (16 * (ib % 2)); - var idx = array( - get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700), - get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700), - get_byte(block.qs[ib], 2) | ((qh) & 0x700), - get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700) - ); - var delta = array( - select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0), - select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0), - select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0), - select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0) - ); - for (var l: u32 = 0; l < 4; l++) { - let ig = idx[l] * 8; - for (var j: u32 = 0; j < 8; j++) { - let gw = iq1_grid[(ig + j) / 16]; - let g = (gw >> (((ig + j) % 16) * 2)) & 3; - let gs = bitcast(g << 30) >> 30; - dst[dst_i] = dl[l/2] * (f32(gs) + delta[l]); - dst_i++; - } - } - } -} - -#enddecl(IQ1_M) - -#decl(IQ4_NL) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); - var dst_i = dst_base + offset * 32; - var qs: array; - for (var i: u32 = 0; i < 4; i++) { - qs[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); - } - for (var j: u32 = 0; j < 16; j++) { - let qsb = get_byte(qs[j / 4], j % 4); - dst[dst_i] = d * f32(kvalues_iq4nl[qsb & 0xF]); - dst[dst_i + 16] = d * f32(kvalues_iq4nl[qsb >> 4]); - dst_i++; - } -} -#enddecl(IQ4_NL) - -#decl(IQ4_XS) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); - let scales_h = bitcast(vec2(block.scales_h, 0.0)); - var dst_i = dst_base + offset * 256; - for (var ib: u32 = 0; ib < 8; ib++) { - let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); - let dl = d * (f32(ls) - 32.0); - for (var j: u32 = 0; j < 16; j++) { - let iqs = ib * 16 + j; - let qsb = get_byte(block.qs[iqs / 4], iqs % 4); - dst[dst_i] = dl * f32(kvalues_iq4nl[qsb & 0xF]); - dst[dst_i + 16] = dl * f32(kvalues_iq4nl[qsb >> 4]); - dst_i++; - } - dst_i += 16; - } -} -#enddecl(IQ4_XS) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -DECLS - -@group(0) @binding(0) -var src: array<{{TYPE}}>; - -@group(0) @binding(1) -var idx: array; - -@group(0) @binding(2) -var dst: array<{{DST_TYPE}}>; - -struct Params { - offset_src: u32, // in elements - offset_idx: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_idx0: u32, - stride_idx1: u32, - stride_idx2: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Shape of dst - ne0: u32, - n_rows: u32, - ne2: u32, - ne3: u32, - - // Shape of idx - idx1: u32, - idx2: u32, -}; - -@group(0) @binding(3) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.n_rows * params.ne2 * params.ne3) { - return; - } - var i = gid.x; - let i_dst3 = i / (params.ne2 * params.n_rows); - - i = i % (params.ne2 * params.n_rows); - let i_dst2 = i / params.n_rows; - let i_dst1 = i % params.n_rows; - - let i_idx2 = i_dst3 % params.idx2; - let i_idx1 = i_dst2 % params.idx1; - let i_idx0 = i_dst1; - - let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; - - let idx_val = u32(idx[i_idx]); - - let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3; - let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3; - - for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) { - copy_elements(i_src_row, i_dst_row, i); - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl new file mode 100644 index 00000000..b10800e3 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -0,0 +1,668 @@ +enable f16; +#include "common_decls.tmpl" + +#ifdef F32_VEC +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset]; +} +#endif + +#ifdef F32 +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + dst[dst_base + offset] = src[src_base + offset]; +} +#endif + +#ifdef F16 +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + dst[dst_base + offset] = f32(src[src_base + offset]); +} +#endif + +#ifdef I32 +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + dst[dst_base + offset] = src[src_base + offset]; +} +#endif + +#ifdef Q4_0 +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_q4_0 = src[src_base + offset]; + let d = f32(block_q4_0.d); + for (var j: u32 = 0; j < 4; j++) { + let q_packed = bitcast(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; + let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; + let dst_offset = dst_base + offset * 32 + j * 4 + k; + dst[dst_offset] = q_lo; + dst[dst_offset + 16] = q_hi; + } + } +} +#endif + +#ifdef Q4_1 +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_q4_1 = src[src_base + offset]; + let d = f32(block_q4_1.d); + let m = f32(block_q4_1.m); + for (var j: u32 = 0; j < 4; j++) { + let q_packed = block_q4_1.qs[j]; + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32((q_byte >> 4) & 0xF) * d + m; + let q_lo = f32(q_byte & 0xF) * d + m; + let dst_offset = dst_base + offset * 32 + j * 4 + k; + dst[dst_offset] = q_lo; + dst[dst_offset + 16] = q_hi; + } + } +} +#endif + +#ifdef Q5_0 +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_q5_0 = src[src_base + offset]; + let d = f32(block_q5_0.d); + let qh_packed = bitcast(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); + for (var j: u32 = 0; j < 4; j++) { + let q_packed = bitcast(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; + let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; + let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; + let dst_offset = dst_base + offset * 32 + j * 4 + k; + dst[dst_offset] = q_lo; + dst[dst_offset + 16] = q_hi; + } + } +} +#endif + +#ifdef Q5_1 +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_q5_1 = src[src_base + offset]; + let d = f32(block_q5_1.d); + let m = f32(block_q5_1.m); + for (var j: u32 = 0; j < 4; j++) { + let q_packed = block_q5_1.qs[j]; + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10; + let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m; + let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10; + let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m; + let dst_offset = dst_base + offset * 32 + j * 4 + k; + dst[dst_offset] = q_lo; + dst[dst_offset + 16] = q_hi; + } + } +} +#endif + +#ifdef Q8_0 +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_q8_0 = src[src_base + offset]; + let d = f32(block_q8_0.d); + for (var j: u32 = 0; j < 8; j++) { + let q_packed = bitcast(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f32(q_byte) * d; + let dst_offset = dst_base + offset * 32 + j * 4 + k; + dst[dst_offset] = q_val; + } + } +} +#endif + +#ifdef Q2_K +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + let m = f32(block.dmin); + var dst_i = dst_base + offset * 256; + var is: u32 = 0; + // 2 halves of the block (128 elements each) + for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { + // 4 groups (each group has 2 blocks of 16 elements) + for (var shift: u32 = 0; shift < 8; shift += 2) { + // 2 blocks + for (var k: u32 = 0; k < 32; k += 16) { + let sc = get_byte(block.scales[is / 4], is % 4); + is++; + let dl = d * f32(sc & 0xF); + let ml = m * f32(sc >> 4); + for (var l: u32 = 0u; l < 16; l++) { + let q_idx = q_b_idx + k + l; + let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); + let qs_val = (q_byte >> shift) & 3; + dst[dst_i] = (f32(qs_val) * dl - ml); + dst_i++; + } + } + } + } +} +#endif + +#ifdef Q3_K +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + + // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, + // and 2-bits from the last 4 bytes + let kmask1: u32 = 0x03030303; + let kmask2: u32 = 0x0f0f0f0f; + var scale_vals: array; + for (var i: u32 = 0; i < 4; i++) { + scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + } + var tmp: u32 = scale_vals[2]; + scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); + scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + // convert arrays of f16 -> u32 + var hmask_vals: array; + for (var i: u32 = 0; i < 8; i++) { + hmask_vals[i] = bitcast(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); + } + var qs_vals: array; + for (var i: u32 = 0; i < 16; i++) { + qs_vals[i] = bitcast(vec2(block.qs[2 * i], block.qs[2 * i + 1])); + } + + var dst_i = dst_base + offset * 256; + var is: u32 = 0; + var m: u32 = 1; + // 2 halves of the block (128 elements each) + for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { + // 4 groups (each group has 2 blocks of 16 elements) + for (var shift: u32 = 0; shift < 8; shift += 2) { + // 2 blocks + for (var k: u32 = 0; k < 32; k += 16) { + let sc = get_byte(scale_vals[is / 4], is % 4); + is++; + let dl = d * (f32(sc) - 32.0); + for (var l: u32 = 0u; l < 16u; l++) { + let q_idx = q_b_idx + k + l; + let hm_idx = k + l; + let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); + let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); + let hm = select(4.0, 0.0, (hmask_byte & m) != 0); + let qs_val = (q_byte >> shift) & 3; + dst[dst_i] = (f32(qs_val) - hm) * dl; + dst_i++; + } + } + m <<= 1; + } + } +} +#endif + +#ifdef Q4_K +// 8 blocks of 32 elements each +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + let m = f32(block.dmin); + var dst_i = dst_base + offset * 256; + var is: u32 = 0; + // 2 blocks each iteration + for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { + for (var shift: u32 = 0; shift < 8; shift += 4) { + let scale_min = get_scale_min(is, block.scales); + is++; + let dl = d * scale_min.x; + let ml = m * scale_min.y; + for (var l: u32 = 0; l < 32; l++) { + let q_idx = q_b_idx + l; + let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); + let qs_val = (q_byte >> shift) & 0xF; + dst[dst_i] = (f32(qs_val) * dl - ml); + dst_i++; + } + } + } +} +#endif + +#ifdef Q5_K +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + let m = f32(block.dmin); + var dst_i = dst_base + offset * 256; + var is: u32 = 0; + var u: u32 = 1; + // 2 blocks each iteration + for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { + for (var shift: u32 = 0; shift < 8; shift += 4) { + let scale_min = get_scale_min(is, block.scales); + is++; + let dl = d * scale_min.x; + let ml = m * scale_min.y; + for (var l: u32 = 0; l < 32; l++) { + let q_idx = q_b_idx + l; + let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); + let qh_byte = get_byte(block.qh[l / 4], l % 4); + let qs_val = (q_byte >> shift) & 0xF; + let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); + dst[dst_i] = (f32(qs_val) + qh_val) * dl - ml; + dst_i++; + } + u <<= 1; + } + } +} +#endif + +#ifdef Q6_K +// 16 blocks of 16 elements each +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + + // convert arrays of f16 -> u32 + var ql_vals: array; + for (var i: u32 = 0; i < 32; i++) { + ql_vals[i] = bitcast(vec2(block.ql[2 * i], block.ql[2 * i + 1])); + } + var qh_vals: array; + for (var i: u32 = 0; i < 16; i++) { + qh_vals[i] = bitcast(vec2(block.qh[2 * i], block.qh[2 * i + 1])); + } + var scale_vals: array; + for (var i: u32 = 0; i < 4; i++) { + scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + } + + var dst_i = dst_base + offset * 256; + var qh_b_idx: u32 = 0; + var sc_b_idx: u32 = 0; + for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) { + for (var l: u32 = 0; l < 32; l++) { + let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4); + let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4); + let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4); + + let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0; + let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0; + let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0; + let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0; + + let is = l/16; + let is1 = sc_b_idx + is; + let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4); + let is2 = sc_b_idx + is + 2; + let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4); + let is3 = sc_b_idx + is + 4; + let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4); + let is4 = sc_b_idx + is + 6; + let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4); + + dst[dst_i + l] = (q1 * f32(sc1)) * d; + dst[dst_i + l + 32] = (q2 * f32(sc2)) * d; + dst[dst_i + l + 64] = (q3 * f32(sc3)) * d; + dst[dst_i + l + 96] = (q4 * f32(sc4)) * d; + } + dst_i += 128; + qh_b_idx += 32; + sc_b_idx += 8; + } +} +#endif + +#ifdef IQ2_XXS +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + var dst_i = dst_base + offset * 256; + for (var ib: u32 = 0; ib < 32; ib += 4) { + let aux0 = bitcast(vec2(block.qs[ib], block.qs[ib + 1])); + let aux1 = bitcast(vec2(block.qs[ib + 2], block.qs[ib + 3])); + let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; + for (var l: u32 = 0; l < 4; l++) { + let ig = get_byte(aux0, l) * 8; + let is = (aux1 >> (7 * l)) & 127; + let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); + for (var j: u32 = 0; j < 8; j++) { + let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); + dst[dst_i] = db * f32(g) * m; + dst_i++; + } + } + } +} +#endif + +#ifdef IQ2_XS +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + var dst_i = dst_base + offset * 256; + var scale_vals = array( + bitcast(vec2(block.scales[0], block.scales[1])), + bitcast(vec2(block.scales[2], block.scales[3])) + ); + for (var ib: u32 = 0; ib < 32; ib += 4) { + let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); + let db = array( + d * (0.5 + f32(s & 0xF)) * 0.25, + d * (0.5 + f32(s >> 4)) * 0.25 + ); + for (var l: u32 = 0; l < 4; l++) { + let qs_val = bitcast(vec2(block.qs[ib + l], 0.0)); + let ig = (qs_val & 511) * 8; + let is = qs_val >> 9; + let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); + let dl = db[l/2]; + for (var j: u32 = 0; j < 8; j++) { + let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); + dst[dst_i] = dl * f32(g) * m; + dst_i++; + } + } + } +} +#endif + +#ifdef IQ2_S +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + var dst_i = dst_base + offset * 256; + var qs_vals : array; + for (var i: u32 = 0; i < 16; i++) { + qs_vals[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + } + var qh_vals = array( + bitcast(vec2(block.qh[0], block.qh[1])), + bitcast(vec2(block.qh[2], block.qh[3])) + ); + var scale_vals = array( + bitcast(vec2(block.scales[0], block.scales[1])), + bitcast(vec2(block.scales[2], block.scales[3])) + ); + for (var ib: u32 = 0; ib < 8; ib ++) { + let s = get_byte(scale_vals[ib / 4], ib % 4); + let db = array( + d * (0.5 + f32(s & 0xF)) * 0.25, + d * (0.5 + f32(s >> 4)) * 0.25 + ); + let qs_w = qs_vals[ib]; + for (var l: u32 = 0; l < 4; l++) { + let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300; + let ig = (get_byte(qs_w, l) | qh_b) * 8; + let signs = get_byte(qs_vals[ib + 8], l); + let dl = db[l/2]; + for (var j: u32 = 0; j < 8; j++) { + let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); + dst[dst_i] = dl * f32(g) * m; + dst_i++; + } + } + } +} +#endif + +#ifdef IQ3_XXS +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + var dst_i = dst_base + offset * 256; + for (var ib: u32 = 0; ib < 16; ib += 2) { + let sc_sign = bitcast(vec2(block.qs[ib + 32], block.qs[ib + 33])); + let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; + for (var l: u32 = 0; l < 4; l++) { + let is = (sc_sign >> (7 * l)) & 127; + let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); + let ig_val = bitcast(vec2(block.qs[ib * 2 + l], 0.0)); + let ig1 = get_byte(ig_val, 0); + let ig2 = get_byte(ig_val, 1); + for (var j: u32 = 0; j < 4; j++) { + let g1 = get_byte(iq3xxs_grid[ig1], j); + let g2 = get_byte(iq3xxs_grid[ig2], j); + let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); + let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); + dst[dst_i] = db * f32(g1) * m1; + dst[dst_i + 4] = db * f32(g2) * m2; + dst_i++; + } + dst_i += 4; + } + } +} +#endif + +#ifdef IQ3_S +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + var dst_i = dst_base + offset * 256; + var qh_vals = array( + bitcast(vec2(block.qh[0], block.qh[1])), + bitcast(vec2(block.qh[2], block.qh[3])) + ); + var sign_vals: array; + for (var i: u32 = 0; i < 8; i++) { + sign_vals[i] = bitcast(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); + } + var scale_vals = bitcast(vec2(block.scales[0], block.scales[1])); + for (var ib: u32 = 0; ib < 4; ib++) { + let s = get_byte(scale_vals, ib); + let db = array( + d * (1.0 + 2.0 * f32(s & 0xF)), + d * (1.0 + 2.0 * f32(s >> 4)) + ); + for (var k: u32 = 0; k < 2; k++) { + let dl = db[k]; + let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k); + let sign_w = sign_vals[ib * 2 + k]; + for (var l: u32 = 0; l < 4; l++) { + let signs = get_byte(sign_w, l); + let ig_val = bitcast(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); + let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); + let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); + for (var j: u32 = 0; j < 4; j++) { + let g1 = get_byte(iq3s_grid[ig1], j); + let g2 = get_byte(iq3s_grid[ig2], j); + let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); + let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); + dst[dst_i] = dl * f32(g1) * m1; + dst[dst_i + 4] = dl * f32(g2) * m2; + dst_i++; + } + dst_i += 4; + } + } + } +} +#endif + +#ifdef IQ1_S +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + var dst_i = dst_base + offset * 256; + for (var ib: u32 = 0; ib < 8; ib++) { + let qh = bitcast(vec2(block.qh[ib], 0.0)); + let dl = d * (2 * f32((qh >> 12) & 7) + 1); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); + let qs_w = bitcast(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); + for (var l: u32 = 0; l < 4; l++) { + let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; + for (var j: u32 = 0; j < 8; j++) { + let gw = iq1_grid[(ig + j) / 16]; + let g = (gw >> (((ig + j) % 16) * 2)) & 3; + let gs = bitcast(g << 30) >> 30; + dst[dst_i] = dl * (f32(gs) + delta); + dst_i++; + } + } + } +} +#endif + +#ifdef IQ1_M +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + + let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000); + let d = f32(bitcast>(scale).x); + var dst_i = dst_base + offset * 256; + for (var ib: u32 = 0; ib < 8; ib++) { + let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF; + let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7; + let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7; + var dl = array( + d * f32(2 * s1 + 1), + d * f32(2 * s2 + 1) + ); + + let qh = block.qh[ib / 2] >> (16 * (ib % 2)); + var idx = array( + get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700), + get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700), + get_byte(block.qs[ib], 2) | ((qh) & 0x700), + get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700) + ); + var delta = array( + select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0), + select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0), + select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0), + select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0) + ); + for (var l: u32 = 0; l < 4; l++) { + let ig = idx[l] * 8; + for (var j: u32 = 0; j < 8; j++) { + let gw = iq1_grid[(ig + j) / 16]; + let g = (gw >> (((ig + j) % 16) * 2)) & 3; + let gs = bitcast(g << 30) >> 30; + dst[dst_i] = dl[l/2] * (f32(gs) + delta[l]); + dst_i++; + } + } + } +} +#endif + +#ifdef IQ4_NL +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + var dst_i = dst_base + offset * 32; + var qs: array; + for (var i: u32 = 0; i < 4; i++) { + qs[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + } + for (var j: u32 = 0; j < 16; j++) { + let qsb = get_byte(qs[j / 4], j % 4); + dst[dst_i] = d * f32(kvalues_iq4nl[qsb & 0xF]); + dst[dst_i + 16] = d * f32(kvalues_iq4nl[qsb >> 4]); + dst_i++; + } +} +#endif + +#ifdef IQ4_XS +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + let scales_h = bitcast(vec2(block.scales_h, 0.0)); + var dst_i = dst_base + offset * 256; + for (var ib: u32 = 0; ib < 8; ib++) { + let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); + let dl = d * (f32(ls) - 32.0); + for (var j: u32 = 0; j < 16; j++) { + let iqs = ib * 16 + j; + let qsb = get_byte(block.qs[iqs / 4], iqs % 4); + dst[dst_i] = dl * f32(kvalues_iq4nl[qsb & 0xF]); + dst[dst_i + 16] = dl * f32(kvalues_iq4nl[qsb >> 4]); + dst_i++; + } + dst_i += 16; + } +} +#endif + +@group(0) @binding(0) +var src: array; + +@group(0) @binding(1) +var idx: array; + +@group(0) @binding(2) +var dst: array; + +struct Params { + offset_src: u32, // in elements + offset_idx: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_idx0: u32, + stride_idx1: u32, + stride_idx2: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Shape of dst + ne0: u32, + n_rows: u32, + ne2: u32, + ne3: u32, + + // Shape of idx + idx1: u32, + idx2: u32, +}; + +@group(0) @binding(3) +var params: Params; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.n_rows * params.ne2 * params.ne3) { + return; + } + var i = gid.x; + let i_dst3 = i / (params.ne2 * params.n_rows); + + i = i % (params.ne2 * params.n_rows); + let i_dst2 = i / params.n_rows; + let i_dst1 = i % params.n_rows; + + let i_idx2 = i_dst3 % params.idx2; + let i_idx1 = i_dst2 % params.idx1; + let i_idx0 = i_dst1; + + let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; + + let idx_val = u32(idx[i_idx]); + + let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3; + let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3; + + for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) { + copy_elements(i_src_row, i_dst_row, i); + } +} + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl deleted file mode 100644 index 0f8e6e5a..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +++ /dev/null @@ -1,907 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "BLOCK_SIZE" : 1 - }, - "DECLS" : ["FLOAT"] - }, - { - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "BLOCK_SIZE" : 1 - }, - "DECLS" : ["FLOAT"] - }, - { - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "BLOCK_SIZE" : 1 - }, - "DECLS" : ["FLOAT"] - }, - { - "REPLS": { - "SRC0_TYPE": "q4_0", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"] - }, - { - "REPLS": { - "SRC0_TYPE": "q4_1", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"] - }, - { - "REPLS": { - "SRC0_TYPE": "q5_0", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"] - }, - { - "REPLS": { - "SRC0_TYPE": "q5_1", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"] - }, - { - "REPLS": { - "SRC0_TYPE": "q8_0", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"] - }, - { - "REPLS": { - "SRC0_TYPE": "q2_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q3_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q4_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q5_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q6_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq2_xxs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq2_xs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq2_s", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq3_xxs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq3_s", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq1_s", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq1_m", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq4_nl", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq4_xs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(FLOAT) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]); -} -#enddecl(FLOAT) - -#decl(Q4_0) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q4_0 = src0[src0_idx_base + offset]; - let d = f32(block_q4_0.d); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#enddecl(Q4_0) - -#decl(Q4_1) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q4_1 = src0[src0_idx_base + offset]; - let d = f32(block_q4_1.d); - let m = f32(block_q4_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_packed = block_q4_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32((q_byte >> 4) & 0xF) * d + m; - let q_lo = f32(q_byte & 0xF) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#enddecl(Q4_1) - -#decl(Q5_0) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q5_0 = src0[src0_idx_base + offset]; - let d = f32(block_q5_0.d); - var sum: f32 = 0.0; - let qh_packed = bitcast(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); - for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; - let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; - let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; - let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#enddecl(Q5_0) - -#decl(Q5_1) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q5_1 = src0[src0_idx_base + offset]; - let d = f32(block_q5_1.d); - let m = f32(block_q5_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_packed = block_q5_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10; - let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m; - let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10; - let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#enddecl(Q5_1) - -#decl(Q8_0) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q8_0 = src0[src0_idx_base + offset]; - let d = f32(block_q8_0.d); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 8; j++) { - let q_packed = bitcast(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_val * f32(src1[src1_offset]); - } - } - return sum; -} -#enddecl(Q8_0) - -#decl(Q8_1) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q8_1 = src0[src0_idx_base + offset]; - let d = f32(block_q8_1.d); - let m = f32(block_q8_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 8; j++) { - let q_packed = block_q8_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_val * f32(src1[src1_offset]); - } - } - return sum; -} -#enddecl(Q8_1) - -#decl(Q2_K) -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - // 2 halves of the block (128 elements each) - for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { - // 4 groups (each group has 2 blocks of 16 elements) - for (var shift: u32 = 0; shift < 8; shift += 2) { - // 2 blocks - for (var k: u32 = 0; k < 32; k += 16) { - let sc = get_byte(block.scales[is / 4], is % 4); - is++; - let dl = d * f32(sc & 0xF); - let ml = m * f32(sc >> 4); - for (var l: u32 = 0u; l < 16; l++) { - let q_idx = q_b_idx + k + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qs_val = (q_byte >> shift) & 3; - sum += (f32(qs_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - } - } - } - return sum; -} - -#enddecl(Q2_K) - -#decl(Q3_K) -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - - // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, - // and 2-bits from the last 4 bytes - let kmask1: u32 = 0x03030303; - let kmask2: u32 = 0x0f0f0f0f; - var scale_vals: array; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); - } - var tmp: u32 = scale_vals[2]; - scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); - scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - - // convert arrays of f16 -> u32 - var hmask_vals: array; - for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = bitcast(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); - } - var qs_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[2 * i], block.qs[2 * i + 1])); - } - - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - var m: u32 = 1; - // 2 halves of the block (128 elements each) - for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { - // 4 groups (each group has 2 blocks of 16 elements) - for (var shift: u32 = 0; shift < 8; shift += 2) { - // 2 blocks - for (var k: u32 = 0; k < 32; k += 16) { - let sc = get_byte(scale_vals[is / 4], is % 4); - is++; - let dl = d * (f32(sc) - 32.0); - for (var l: u32 = 0u; l < 16u; l++) { - let q_idx = q_b_idx + k + l; - let hm_idx = k + l; - let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); - let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); - let hm = select(4.0, 0.0, (hmask_byte & m) != 0); - let qs_val = (q_byte >> shift) & 3; - sum += ((f32(qs_val) - hm) * dl) * src1[src1_i]; - src1_i++; - } - } - m <<= 1; - } - } - return sum; -} - -#enddecl(Q3_K) - -#decl(Q4_K) -// 8 blocks of 32 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - // 2 blocks each iteration - for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { - for (var shift: u32 = 0; shift < 8; shift += 4) { - let scale_min = get_scale_min(is, block.scales); - is++; - let dl = d * scale_min.x; - let ml = m * scale_min.y; - for (var l: u32 = 0; l < 32; l++) { - let q_idx = q_b_idx + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qs_val = (q_byte >> shift) & 0xF; - sum += (f32(qs_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} - -#enddecl(Q4_K) - -#decl(Q5_K) -// 8 blocks of 32 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - var u: u32 = 1; - // 2 blocks each iteration - for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { - for (var shift: u32 = 0; shift < 8; shift += 4) { - let scale_min = get_scale_min(is, block.scales); - is++; - let dl = d * scale_min.x; - let ml = m * scale_min.y; - for (var l: u32 = 0; l < 32; l++) { - let q_idx = q_b_idx + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qh_byte = get_byte(block.qh[l / 4], l % 4); - let qs_val = (q_byte >> shift) & 0xF; - let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); - sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - u <<= 1; - } - } - return sum; -} - -#enddecl(Q5_K) - -#decl(Q6_K) -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - - // convert arrays of f16 -> u32 - var ql_vals: array; - for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = bitcast(vec2(block.ql[2 * i], block.ql[2 * i + 1])); - } - var qh_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = bitcast(vec2(block.qh[2 * i], block.qh[2 * i + 1])); - } - var scale_vals: array; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); - } - - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var qh_b_idx: u32 = 0; - var sc_b_idx: u32 = 0; - for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) { - for (var l: u32 = 0; l < 32; l++) { - let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4); - let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4); - let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4); - - let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0; - let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0; - let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0; - let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0; - - let is = l/16; - let is1 = sc_b_idx + is; - let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4); - let is2 = sc_b_idx + is + 2; - let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4); - let is3 = sc_b_idx + is + 4; - let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4); - let is4 = sc_b_idx + is + 6; - let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4); - - sum += d * f32(sc1) * q1 * src1[src1_i + l]; - sum += d * f32(sc2) * q2 * src1[src1_i + l + 32]; - sum += d * f32(sc3) * q3 * src1[src1_i + l + 64]; - sum += d * f32(sc4) * q4 * src1[src1_i + l + 96]; - } - src1_i += 128; - qh_b_idx += 32; - sc_b_idx += 8; - } - return sum; -} - -#enddecl(Q6_K) - -#decl(IQ2_XXS) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0 = bitcast(vec2(block.qs[ib], block.qs[ib + 1])); - let aux1 = bitcast(vec2(block.qs[ib + 2], block.qs[ib + 3])); - let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; - for (var l: u32 = 0; l < 4; l++) { - let ig = get_byte(aux0, l) * 8; - let is = (aux1 >> (7 * l)) & 127; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += db * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} - -#enddecl(IQ2_XXS) - -#decl(IQ2_XS) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - var src1_i = src1_idx_base + offset * 256; - var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) - ); - var sum = 0.0; - for (var ib: u32 = 0; ib < 32; ib += 4) { - let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); - let db = array( - d * (0.5 + f32(s & 0xF)) * 0.25, - d * (0.5 + f32(s >> 4)) * 0.25 - ); - for (var l: u32 = 0; l < 4; l++) { - let qs_val = bitcast(vec2(block.qs[ib + l], 0.0)); - let ig = (qs_val & 511) * 8; - let is = qs_val >> 9; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let dl = db[l/2]; - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += dl * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} - -#enddecl(IQ2_XS) - -#decl(IQ2_S) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - var src1_i = src1_idx_base + offset * 256; - var qs_vals : array; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); - } - var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) - ); - var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) - ); - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib ++) { - let s = get_byte(scale_vals[ib / 4], ib % 4); - let db = array( - d * (0.5 + f32(s & 0xF)) * 0.25, - d * (0.5 + f32(s >> 4)) * 0.25 - ); - let qs_w = qs_vals[ib]; - for (var l: u32 = 0; l < 4; l++) { - let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300; - let ig = (get_byte(qs_w, l) | qh_b) * 8; - let signs = get_byte(qs_vals[ib + 8], l); - let dl = db[l/2]; - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += dl * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} - - -#enddecl(IQ2_S) - -#decl(IQ3_XSS) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign = bitcast(vec2(block.qs[ib + 32], block.qs[ib + 33])); - let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; - for (var l: u32 = 0; l < 4; l++) { - let is = (sc_sign >> (7 * l)) & 127; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = bitcast(vec2(block.qs[ib * 2 + l], 0.0)); - let ig1 = get_byte(ig_val, 0); - let ig2 = get_byte(ig_val, 1); - for (var j: u32 = 0; j < 4; j++) { - let g1 = get_byte(iq3xxs_grid[ig1], j); - let g2 = get_byte(iq3xxs_grid[ig2], j); - let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); - let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); - sum += db * f32(g1) * m1 * src1[src1_i]; - sum += db * f32(g2) * m2 * src1[src1_i + 4]; - src1_i++; - } - src1_i += 4; - } - } - return sum; -} - -#enddecl(IQ3_XSS) - -#decl(IQ3_S) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - var src1_i = src1_idx_base + offset * 256; - var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) - ); - var sign_vals: array; - for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = bitcast(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); - } - var scale_vals = bitcast(vec2(block.scales[0], block.scales[1])); - var sum = 0.0; - for (var ib: u32 = 0; ib < 4; ib++) { - let s = get_byte(scale_vals, ib); - let db = array( - d * (1.0 + 2.0 * f32(s & 0xF)), - d * (1.0 + 2.0 * f32(s >> 4)) - ); - for (var k: u32 = 0; k < 2; k++) { - let dl = db[k]; - let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k); - let sign_w = sign_vals[ib * 2 + k]; - for (var l: u32 = 0; l < 4; l++) { - let signs = get_byte(sign_w, l); - let ig_val = bitcast(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); - let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); - let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); - for (var j: u32 = 0; j < 4; j++) { - let g1 = get_byte(iq3s_grid[ig1], j); - let g2 = get_byte(iq3s_grid[ig2], j); - let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); - let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); - sum += dl * f32(g1) * m1 * src1[src1_i]; - sum += dl * f32(g2) * m2 * src1[src1_i + 4]; - src1_i++; - } - src1_i += 4; - } - } - } - return sum; -} -#enddecl(IQ3_S) - -#decl(IQ1_S) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let qh = bitcast(vec2(block.qh[ib], 0.0)); - let dl = d * (2 * f32((qh >> 12) & 7) + 1); - let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = bitcast(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); - for (var l: u32 = 0; l < 4; l++) { - let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; - for (var j: u32 = 0; j < 8; j++) { - let gw = iq1_grid[(ig + j) / 16]; - let g = (gw >> (((ig + j) % 16) * 2)) & 3; - let gs = bitcast(g << 30) >> 30; - sum += dl * (f32(gs) + delta) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} - -#enddecl(IQ1_S) - -#decl(IQ1_M) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - - let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000); - let d = f32(bitcast>(scale).x); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF; - let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7; - let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7; - var dl = array( - d * f32(2 * s1 + 1), - d * f32(2 * s2 + 1) - ); - - let qh = block.qh[ib / 2] >> (16 * (ib % 2)); - var idx = array( - get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700), - get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700), - get_byte(block.qs[ib], 2) | ((qh) & 0x700), - get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700) - ); - var delta = array( - select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0), - select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0), - select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0), - select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0) - ); - for (var l: u32 = 0; l < 4; l++) { - let ig = idx[l] * 8; - for (var j: u32 = 0; j < 8; j++) { - let gw = iq1_grid[(ig + j) / 16]; - let g = (gw >> (((ig + j) % 16) * 2)) & 3; - let gs = bitcast(g << 30) >> 30; - sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} - -#enddecl(IQ1_M) - -#decl(IQ4_NL) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - var src1_i = src1_idx_base + offset * 32; - var sum = 0.0; - var qs: array; - for (var i: u32 = 0; i < 4; i++) { - qs[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); - } - for (var j: u32 = 0; j < 16; j++) { - let qsb = get_byte(qs[j / 4], j % 4); - sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; - sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; - src1_i++; - } - return sum; -} - -#enddecl(IQ4_NL) - -#decl(IQ4_XS) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let scales_h = bitcast(vec2(block.scales_h, 0.0)); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); - let dl = d * (f32(ls) - 32.0); - for (var j: u32 = 0; j < 16; j++) { - let iqs = ib * 16 + j; - let qsb = get_byte(block.qs[iqs / 4], iqs % 4); - sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; - sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; - src1_i++; - } - src1_i += 16; - } - return sum; -} - -#enddecl(IQ4_XS) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -DECLS - -struct MulMatParams { - offset_src0: u32, // in elements/blocks - offset_src1: u32, // in elements/blocks - offset_dst: u32, // in elements/blocks - m: u32, - n: u32, - k: u32, - // all strides are in elements/blocks - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array; // M rows, N columns - -@group(0) @binding(3) var params: MulMatParams; - -@compute @workgroup_size(256) -fn main(@builtin(global_invocation_id) global_id: vec3) { - let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - if (global_id.x >= total) { - return; - } - - let dst2_stride = params.m * params.n; - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - - let dst3_idx = global_id.x / dst3_stride; - let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension - let src13_idx = dst3_idx; // src1 is not broadcast - let dst3_rem = global_id.x % dst3_stride; - - let dst2_idx = dst3_rem / dst2_stride; - let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension - let src12_idx = dst2_idx; // src1 is not broadcast - - let dst2_rem = dst3_rem % dst2_stride; - - let row = dst2_rem / params.m; // output row - let col = dst2_rem % params.m; // output column - - let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; - - var sum = 0.0; - for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) { - sum += multiply_add(src0_idx_base, src1_idx_base, i); - } - dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum; -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl new file mode 100644 index 00000000..6aba4731 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl @@ -0,0 +1,713 @@ +enable f16; + +#include "common_decls.tmpl" + +#ifdef FLOAT +const BLOCK_SIZE = 1u; + +#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL) +const BLOCK_SIZE = 32u; + +#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS) +const BLOCK_SIZE = 256u; +#endif + +#ifdef FLOAT +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]); +} +#endif + +#ifdef Q4_0 +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q4_0 = src0[src0_idx_base + offset]; + let d = f32(block_q4_0.d); + var sum: f32 = 0.0; + for (var j: u32 = 0; j < 4; j++) { + let q_packed = bitcast(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; + let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_lo * f32(src1[src1_offset]); + sum += q_hi * f32(src1[src1_offset + 16]); + } + } + return sum; +} +#endif + +#ifdef Q4_1 +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q4_1 = src0[src0_idx_base + offset]; + let d = f32(block_q4_1.d); + let m = f32(block_q4_1.m); + var sum: f32 = 0.0; + for (var j: u32 = 0; j < 4; j++) { + let q_packed = block_q4_1.qs[j]; + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32((q_byte >> 4) & 0xF) * d + m; + let q_lo = f32(q_byte & 0xF) * d + m; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_lo * f32(src1[src1_offset]); + sum += q_hi * f32(src1[src1_offset + 16]); + } + } + return sum; +} +#endif + +#ifdef Q5_0 +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q5_0 = src0[src0_idx_base + offset]; + let d = f32(block_q5_0.d); + var sum: f32 = 0.0; + let qh_packed = bitcast(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); + for (var j: u32 = 0; j < 4; j++) { + let q_packed = bitcast(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; + let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; + let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_lo * f32(src1[src1_offset]); + sum += q_hi * f32(src1[src1_offset + 16]); + } + } + return sum; +} +#endif + +#ifdef Q5_1 +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q5_1 = src0[src0_idx_base + offset]; + let d = f32(block_q5_1.d); + let m = f32(block_q5_1.m); + var sum: f32 = 0.0; + for (var j: u32 = 0; j < 4; j++) { + let q_packed = block_q5_1.qs[j]; + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10; + let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m; + let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10; + let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_lo * f32(src1[src1_offset]); + sum += q_hi * f32(src1[src1_offset + 16]); + } + } + return sum; +} +#endif + +#ifdef Q8_0 +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q8_0 = src0[src0_idx_base + offset]; + let d = f32(block_q8_0.d); + var sum: f32 = 0.0; + for (var j: u32 = 0; j < 8; j++) { + let q_packed = bitcast(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f32(q_byte) * d; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_val * f32(src1[src1_offset]); + } + } + return sum; +} +#endif + +#ifdef Q8_1 +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q8_1 = src0[src0_idx_base + offset]; + let d = f32(block_q8_1.d); + let m = f32(block_q8_1.m); + var sum: f32 = 0.0; + for (var j: u32 = 0; j < 8; j++) { + let q_packed = block_q8_1.qs[j]; + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f32(q_byte) * d + m; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_val * f32(src1[src1_offset]); + } + } + return sum; +} +#endif + +#ifdef Q2_K +// 16 blocks of 16 elements each +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + let m = f32(block.dmin); + var sum = 0.0; + var src1_i = src1_idx_base + offset * 256; + var is: u32 = 0; + // 2 halves of the block (128 elements each) + for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { + // 4 groups (each group has 2 blocks of 16 elements) + for (var shift: u32 = 0; shift < 8; shift += 2) { + // 2 blocks + for (var k: u32 = 0; k < 32; k += 16) { + let sc = get_byte(block.scales[is / 4], is % 4); + is++; + let dl = d * f32(sc & 0xF); + let ml = m * f32(sc >> 4); + for (var l: u32 = 0u; l < 16; l++) { + let q_idx = q_b_idx + k + l; + let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); + let qs_val = (q_byte >> shift) & 3; + sum += (f32(qs_val) * dl - ml) * src1[src1_i]; + src1_i++; + } + } + } + } + return sum; +} +#endif + +#ifdef Q3_K +// 16 blocks of 16 elements each +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + + // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, + // and 2-bits from the last 4 bytes + let kmask1: u32 = 0x03030303; + let kmask2: u32 = 0x0f0f0f0f; + var scale_vals: array; + for (var i: u32 = 0; i < 4; i++) { + scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + } + var tmp: u32 = scale_vals[2]; + scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); + scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + // convert arrays of f16 -> u32 + var hmask_vals: array; + for (var i: u32 = 0; i < 8; i++) { + hmask_vals[i] = bitcast(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); + } + var qs_vals: array; + for (var i: u32 = 0; i < 16; i++) { + qs_vals[i] = bitcast(vec2(block.qs[2 * i], block.qs[2 * i + 1])); + } + + var sum = 0.0; + var src1_i = src1_idx_base + offset * 256; + var is: u32 = 0; + var m: u32 = 1; + // 2 halves of the block (128 elements each) + for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { + // 4 groups (each group has 2 blocks of 16 elements) + for (var shift: u32 = 0; shift < 8; shift += 2) { + // 2 blocks + for (var k: u32 = 0; k < 32; k += 16) { + let sc = get_byte(scale_vals[is / 4], is % 4); + is++; + let dl = d * (f32(sc) - 32.0); + for (var l: u32 = 0u; l < 16u; l++) { + let q_idx = q_b_idx + k + l; + let hm_idx = k + l; + let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); + let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); + let hm = select(4.0, 0.0, (hmask_byte & m) != 0); + let qs_val = (q_byte >> shift) & 3; + sum += ((f32(qs_val) - hm) * dl) * src1[src1_i]; + src1_i++; + } + } + m <<= 1; + } + } + return sum; +} +#endif + +#ifdef Q4_K +// 8 blocks of 32 elements each +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + let m = f32(block.dmin); + var sum = 0.0; + var src1_i = src1_idx_base + offset * 256; + var is: u32 = 0; + // 2 blocks each iteration + for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { + for (var shift: u32 = 0; shift < 8; shift += 4) { + let scale_min = get_scale_min(is, block.scales); + is++; + let dl = d * scale_min.x; + let ml = m * scale_min.y; + for (var l: u32 = 0; l < 32; l++) { + let q_idx = q_b_idx + l; + let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); + let qs_val = (q_byte >> shift) & 0xF; + sum += (f32(qs_val) * dl - ml) * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} +#endif + +#ifdef Q5_K +// 8 blocks of 32 elements each +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + let m = f32(block.dmin); + var sum = 0.0; + var src1_i = src1_idx_base + offset * 256; + var is: u32 = 0; + var u: u32 = 1; + // 2 blocks each iteration + for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { + for (var shift: u32 = 0; shift < 8; shift += 4) { + let scale_min = get_scale_min(is, block.scales); + is++; + let dl = d * scale_min.x; + let ml = m * scale_min.y; + for (var l: u32 = 0; l < 32; l++) { + let q_idx = q_b_idx + l; + let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); + let qh_byte = get_byte(block.qh[l / 4], l % 4); + let qs_val = (q_byte >> shift) & 0xF; + let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); + sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i]; + src1_i++; + } + u <<= 1; + } + } + return sum; +} +#endif + +#ifdef Q6_K +// 16 blocks of 16 elements each +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + + // convert arrays of f16 -> u32 + var ql_vals: array; + for (var i: u32 = 0; i < 32; i++) { + ql_vals[i] = bitcast(vec2(block.ql[2 * i], block.ql[2 * i + 1])); + } + var qh_vals: array; + for (var i: u32 = 0; i < 16; i++) { + qh_vals[i] = bitcast(vec2(block.qh[2 * i], block.qh[2 * i + 1])); + } + var scale_vals: array; + for (var i: u32 = 0; i < 4; i++) { + scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + } + + var sum = 0.0; + var src1_i = src1_idx_base + offset * 256; + var qh_b_idx: u32 = 0; + var sc_b_idx: u32 = 0; + for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) { + for (var l: u32 = 0; l < 32; l++) { + let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4); + let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4); + let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4); + + let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0; + let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0; + let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0; + let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0; + + let is = l/16; + let is1 = sc_b_idx + is; + let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4); + let is2 = sc_b_idx + is + 2; + let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4); + let is3 = sc_b_idx + is + 4; + let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4); + let is4 = sc_b_idx + is + 6; + let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4); + + sum += d * f32(sc1) * q1 * src1[src1_i + l]; + sum += d * f32(sc2) * q2 * src1[src1_i + l + 32]; + sum += d * f32(sc3) * q3 * src1[src1_i + l + 64]; + sum += d * f32(sc4) * q4 * src1[src1_i + l + 96]; + } + src1_i += 128; + qh_b_idx += 32; + sc_b_idx += 8; + } + return sum; +} +#endif + +#ifdef IQ2_XXS +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var sum = 0.0; + for (var ib: u32 = 0; ib < 32; ib += 4) { + let aux0 = bitcast(vec2(block.qs[ib], block.qs[ib + 1])); + let aux1 = bitcast(vec2(block.qs[ib + 2], block.qs[ib + 3])); + let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; + for (var l: u32 = 0; l < 4; l++) { + let ig = get_byte(aux0, l) * 8; + let is = (aux1 >> (7 * l)) & 127; + let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); + for (var j: u32 = 0; j < 8; j++) { + let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); + sum += db * f32(g) * m * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} +#endif + +#ifdef IQ2_XS +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var scale_vals = array( + bitcast(vec2(block.scales[0], block.scales[1])), + bitcast(vec2(block.scales[2], block.scales[3])) + ); + var sum = 0.0; + for (var ib: u32 = 0; ib < 32; ib += 4) { + let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); + let db = array( + d * (0.5 + f32(s & 0xF)) * 0.25, + d * (0.5 + f32(s >> 4)) * 0.25 + ); + for (var l: u32 = 0; l < 4; l++) { + let qs_val = bitcast(vec2(block.qs[ib + l], 0.0)); + let ig = (qs_val & 511) * 8; + let is = qs_val >> 9; + let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); + let dl = db[l/2]; + for (var j: u32 = 0; j < 8; j++) { + let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); + sum += dl * f32(g) * m * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} +#endif + +#ifdef IQ2_S +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var qs_vals : array; + for (var i: u32 = 0; i < 16; i++) { + qs_vals[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + } + var qh_vals = array( + bitcast(vec2(block.qh[0], block.qh[1])), + bitcast(vec2(block.qh[2], block.qh[3])) + ); + var scale_vals = array( + bitcast(vec2(block.scales[0], block.scales[1])), + bitcast(vec2(block.scales[2], block.scales[3])) + ); + var sum = 0.0; + for (var ib: u32 = 0; ib < 8; ib ++) { + let s = get_byte(scale_vals[ib / 4], ib % 4); + let db = array( + d * (0.5 + f32(s & 0xF)) * 0.25, + d * (0.5 + f32(s >> 4)) * 0.25 + ); + let qs_w = qs_vals[ib]; + for (var l: u32 = 0; l < 4; l++) { + let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300; + let ig = (get_byte(qs_w, l) | qh_b) * 8; + let signs = get_byte(qs_vals[ib + 8], l); + let dl = db[l/2]; + for (var j: u32 = 0; j < 8; j++) { + let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); + sum += dl * f32(g) * m * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} +#endif + +#ifdef IQ3_XXS +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var sum = 0.0; + for (var ib: u32 = 0; ib < 16; ib += 2) { + let sc_sign = bitcast(vec2(block.qs[ib + 32], block.qs[ib + 33])); + let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; + for (var l: u32 = 0; l < 4; l++) { + let is = (sc_sign >> (7 * l)) & 127; + let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); + let ig_val = bitcast(vec2(block.qs[ib * 2 + l], 0.0)); + let ig1 = get_byte(ig_val, 0); + let ig2 = get_byte(ig_val, 1); + for (var j: u32 = 0; j < 4; j++) { + let g1 = get_byte(iq3xxs_grid[ig1], j); + let g2 = get_byte(iq3xxs_grid[ig2], j); + let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); + let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); + sum += db * f32(g1) * m1 * src1[src1_i]; + sum += db * f32(g2) * m2 * src1[src1_i + 4]; + src1_i++; + } + src1_i += 4; + } + } + return sum; +} +#endif + +#ifdef IQ3_S +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var qh_vals = array( + bitcast(vec2(block.qh[0], block.qh[1])), + bitcast(vec2(block.qh[2], block.qh[3])) + ); + var sign_vals: array; + for (var i: u32 = 0; i < 8; i++) { + sign_vals[i] = bitcast(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); + } + var scale_vals = bitcast(vec2(block.scales[0], block.scales[1])); + var sum = 0.0; + for (var ib: u32 = 0; ib < 4; ib++) { + let s = get_byte(scale_vals, ib); + let db = array( + d * (1.0 + 2.0 * f32(s & 0xF)), + d * (1.0 + 2.0 * f32(s >> 4)) + ); + for (var k: u32 = 0; k < 2; k++) { + let dl = db[k]; + let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k); + let sign_w = sign_vals[ib * 2 + k]; + for (var l: u32 = 0; l < 4; l++) { + let signs = get_byte(sign_w, l); + let ig_val = bitcast(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); + let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); + let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); + for (var j: u32 = 0; j < 4; j++) { + let g1 = get_byte(iq3s_grid[ig1], j); + let g2 = get_byte(iq3s_grid[ig2], j); + let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); + let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); + sum += dl * f32(g1) * m1 * src1[src1_i]; + sum += dl * f32(g2) * m2 * src1[src1_i + 4]; + src1_i++; + } + src1_i += 4; + } + } + } + return sum; +} +#endif + +#ifdef IQ1_S +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var sum = 0.0; + for (var ib: u32 = 0; ib < 8; ib++) { + let qh = bitcast(vec2(block.qh[ib], 0.0)); + let dl = d * (2 * f32((qh >> 12) & 7) + 1); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); + let qs_w = bitcast(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); + for (var l: u32 = 0; l < 4; l++) { + let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; + for (var j: u32 = 0; j < 8; j++) { + let gw = iq1_grid[(ig + j) / 16]; + let g = (gw >> (((ig + j) % 16) * 2)) & 3; + let gs = bitcast(g << 30) >> 30; + sum += dl * (f32(gs) + delta) * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} +#endif + + +#ifdef IQ1_M +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + + let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000); + let d = f32(bitcast>(scale).x); + var src1_i = src1_idx_base + offset * 256; + var sum = 0.0; + for (var ib: u32 = 0; ib < 8; ib++) { + let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF; + let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7; + let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7; + var dl = array( + d * f32(2 * s1 + 1), + d * f32(2 * s2 + 1) + ); + + let qh = block.qh[ib / 2] >> (16 * (ib % 2)); + var idx = array( + get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700), + get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700), + get_byte(block.qs[ib], 2) | ((qh) & 0x700), + get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700) + ); + var delta = array( + select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0), + select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0), + select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0), + select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0) + ); + for (var l: u32 = 0; l < 4; l++) { + let ig = idx[l] * 8; + for (var j: u32 = 0; j < 8; j++) { + let gw = iq1_grid[(ig + j) / 16]; + let g = (gw >> (((ig + j) % 16) * 2)) & 3; + let gs = bitcast(g << 30) >> 30; + sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} +#endif + +#ifdef IQ4_NL +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 32; + var sum = 0.0; + var qs: array; + for (var i: u32 = 0; i < 4; i++) { + qs[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + } + for (var j: u32 = 0; j < 16; j++) { + let qsb = get_byte(qs[j / 4], j % 4); + sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; + sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; + src1_i++; + } + return sum; +} +#endif + +#ifdef IQ4_XS +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + let scales_h = bitcast(vec2(block.scales_h, 0.0)); + var src1_i = src1_idx_base + offset * 256; + var sum = 0.0; + for (var ib: u32 = 0; ib < 8; ib++) { + let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); + let dl = d * (f32(ls) - 32.0); + for (var j: u32 = 0; j < 16; j++) { + let iqs = ib * 16 + j; + let qsb = get_byte(block.qs[iqs / 4], iqs % 4); + sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; + sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; + src1_i++; + } + src1_i += 16; + } + return sum; +} +#endif + +struct MulMatParams { + offset_src0: u32, // in elements/blocks + offset_src1: u32, // in elements/blocks + offset_dst: u32, // in elements/blocks + m: u32, + n: u32, + k: u32, + // all strides are in elements/blocks + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns + +@group(0) @binding(3) var params: MulMatParams; + +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + if (global_id.x >= total) { + return; + } + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = global_id.x / dst3_stride; + let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension + let src13_idx = dst3_idx; // src1 is not broadcast + let dst3_rem = global_id.x % dst3_stride; + + let dst2_idx = dst3_rem / dst2_stride; + let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension + let src12_idx = dst2_idx; // src1 is not broadcast + + let dst2_rem = dst3_rem % dst2_stride; + + let row = dst2_rem / params.m; // output row + let col = dst2_rem % params.m; // output column + + let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; + + var sum = 0.0; + for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) { + sum += multiply_add(src0_idx_base, src1_idx_base, i); + } + dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum; +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 109ff8d6..5c1074eb 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -1,58 +1,65 @@ -#decl(SHMEM_VEC) +#ifdef VEC +#define VEC_SIZE 4 +#define SHMEM_TYPE vec4 +#define DST_TYPE vec4 +#define SRC0_TYPE vec4 +#define SRC1_TYPE vec4 + fn store_shmem(val: vec4, idx: u32) { shmem[idx] = val.x; shmem[idx + 1] = val.y; shmem[idx + 2] = val.z; shmem[idx + 3] = val.w; } -#enddecl(SHMEM_VEC) +#endif + +#ifdef SCALAR +#define VEC_SIZE 1 +#define SHMEM_TYPE f16 +#define DST_TYPE f32 +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE -#decl(SHMEM_SCALAR) fn store_shmem(val: f16, idx: u32) { shmem[idx] = val; } -#enddecl(SHMEM_SCALAR) - -#decl(INIT_SRC0_SHMEM_FLOAT) +#endif +#ifdef INIT_SRC0_SHMEM_FLOAT fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { let tile_m = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; let global_m = offset_m + tile_m; let global_k = k_outer + tile_k; let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let src0_val = select( // taking a slight performance hit to avoid oob - {{SRC0_TYPE}}(0.0), - src0[src0_idx/{{VEC_SIZE}}], + SRC0_TYPE(0.0), + src0[src0_idx/VEC_SIZE], global_m < params.m && global_k < params.k); - store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx); + store_shmem(SHMEM_TYPE(src0_val), elem_idx); } } +#endif -#enddecl(INIT_SRC0_SHMEM_FLOAT) - -#decl(INIT_SRC1_SHMEM) - +#ifdef INIT_SRC1_SHMEM_FLOAT fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { - for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { let tile_n = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; let global_n = offset_n + tile_n; let global_k = k_outer + tile_k; let src1_idx = batch_offset + global_n * params.stride_11 + global_k; let src1_val = select( - {{SRC1_TYPE}}(0.0), - src1[src1_idx/{{VEC_SIZE}}], + SRC1_TYPE(0.0), + src1[src1_idx/VEC_SIZE], global_n < params.n && global_k < params.k); - store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx); + store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx); } } +#endif -#enddecl(INIT_SRC1_SHMEM) - -#decl(INIT_SRC0_SHMEM_Q4_0) - +#ifdef INIT_SRC0_SHMEM_Q4_0 const BLOCK_SIZE = 32u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; @@ -93,5 +100,4 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } } - -#enddecl(INIT_SRC0_SHMEM_Q4_0) +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl deleted file mode 100644 index 6b1dd26c..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +++ /dev/null @@ -1,247 +0,0 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "f32_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f32_f32", - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32_vec", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(VEC) -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { - return vec4(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); -} -#enddecl(VEC) - -#decl(SCALAR) -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { - return f32(acc[tm][tn]); -} -#enddecl(SCALAR) - -#end(DECLS) - -#define(SHADER) -enable f16; - -struct MulMatParams { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - m: u32, - n: u32, - k: u32, - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) - -@group(0) @binding(3) var params: MulMatParams; - -DECLS - -fn get_local_n(thread_id: u32) -> u32 { - return thread_id / WORKGROUP_SIZE_M; -} -fn get_local_m(thread_id: u32) -> u32 { - return thread_id % WORKGROUP_SIZE_M; -} - -// TILE_M must be multiple of 4 for vec4 loads -const TILE_M = {{WEBGPU_TILE_M}}u; -const TILE_N = {{WEBGPU_TILE_N}}u; - -override WORKGROUP_SIZE_M: u32; -override WORKGROUP_SIZE_N: u32; -override TILE_K: u32; - -override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; -override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; -override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; - -var shmem: array; - -@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) -fn main(@builtin(workgroup_id) wg_id: vec3, - @builtin(local_invocation_id) local_id: vec3) { - - let thread_id = local_id.x; - let local_m = get_local_m(thread_id); - let local_n = get_local_n(thread_id); - - let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N); - let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); - let wg_per_matrix = wg_m_count * wg_n_count; - - let batch_idx = wg_id.x / wg_per_matrix; - - let wg_in_batch = wg_id.x % wg_per_matrix; - let wg_m = wg_in_batch % wg_m_count; - let wg_n = wg_in_batch / wg_m_count; - - let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M; - let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N; - - let dst2_stride = params.m * params.n; - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - - let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); - let src03_idx = dst3_idx / params.broadcast3; - let src13_idx = dst3_idx; - let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); - let src02_idx = dst2_idx / params.broadcast2; - let src12_idx = dst2_idx; - - let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; - let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - - let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M; - let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N; - - var acc: array, TILE_M>; - - for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { - - // see mul_mat_decls.tmpl - init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); - init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); - - workgroupBarrier(); - - let k_end = min(TILE_K, params.k - k_outer); - - for (var k_inner = 0u; k_inner < k_end; k_inner++) { - var src0_tile: array; - for (var tm = 0u; tm < TILE_M; tm++) { - let src0_m = local_m * TILE_M + tm; - let src0_idx = k_inner + src0_m * TILE_K; - src0_tile[tm] = shmem[src0_idx]; - } - for (var tn = 0u; tn < TILE_N; tn++) { - let src1_n = local_n * TILE_N + tn; - let src1_idx = src1_n * TILE_K + k_inner; - let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx]; - for (var tm = 0u; tm < TILE_M; tm++) { - acc[tm][tn] += src0_tile[tm] * src1_val; - } - } - } - - workgroupBarrier(); - } - - let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; - - for (var tn = 0u; tn < TILE_N; tn++) { - let global_col = output_col_base + tn; - if (global_col < params.n) { - for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) { - let global_row = output_row_base + tm; - if (global_row < params.m) { - let dst_idx = dst_batch_offset + global_col * params.m + global_row; - dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm); - } - } - } - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl new file mode 100644 index 00000000..771e5cd1 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -0,0 +1,138 @@ +enable f16; + +#include "common_decls.tmpl" +#include "mul_mat_decls.tmpl" + +#ifdef VEC +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { + return vec4(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); +} +#endif + +#ifdef SCALAR +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { + return f32(acc[tm][tn]); +} +#endif + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) + +@group(0) @binding(3) var params: MulMatParams; + +fn get_local_n(thread_id: u32) -> u32 { + return thread_id / WORKGROUP_SIZE_M; +} +fn get_local_m(thread_id: u32) -> u32 { + return thread_id % WORKGROUP_SIZE_M; +} + +const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; +const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; +const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; +var shmem: array; + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3) { + + let thread_id = local_id.x; + let local_m = get_local_m(thread_id); + let local_n = get_local_n(thread_id); + + let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N); + let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); + let wg_per_matrix = wg_m_count * wg_n_count; + + let batch_idx = wg_id.x / wg_per_matrix; + + let wg_in_batch = wg_id.x % wg_per_matrix; + let wg_m = wg_in_batch % wg_m_count; + let wg_n = wg_in_batch / wg_m_count; + + let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M; + let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N; + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + + let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M; + let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N; + + var acc: array, TILE_M>; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + // see mul_mat_decls.tmpl + init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); + init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); + + workgroupBarrier(); + + let k_end = min(TILE_K, params.k - k_outer); + + for (var k_inner = 0u; k_inner < k_end; k_inner++) { + var src0_tile: array; + for (var tm = 0u; tm < TILE_M; tm++) { + let src0_m = local_m * TILE_M + tm; + let src0_idx = k_inner + src0_m * TILE_K; + src0_tile[tm] = shmem[src0_idx]; + } + for (var tn = 0u; tn < TILE_N; tn++) { + let src1_n = local_n * TILE_N + tn; + let src1_idx = src1_n * TILE_K + k_inner; + let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx]; + for (var tm = 0u; tm < TILE_M; tm++) { + acc[tm][tn] += src0_tile[tm] * src1_val; + } + } + } + + workgroupBarrier(); + } + + let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; + + for (var tn = 0u; tn < TILE_N; tn++) { + let global_col = output_col_base + tn; + if (global_col < params.n) { + for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) { + let global_row = output_row_base + tm; + if (global_row < params.m) { + let dst_idx = dst_batch_offset + global_col * params.m + global_row; + dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm); + } + } + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl deleted file mode 100644 index 47c8ce36..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ /dev/null @@ -1,302 +0,0 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "f32_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f32_f32", - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32_vec", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(VEC) -fn store_dst(shmem_idx: u32, dst_idx: u32) { - dst[dst_idx] = vec4( - f32(shmem[shmem_idx]), - f32(shmem[shmem_idx + 1]), - f32(shmem[shmem_idx + 2]), - f32(shmem[shmem_idx + 3]) - ); -} -#enddecl(VEC) - -#decl(SCALAR) -fn store_dst(shmem_idx: u32, dst_idx: u32) { - dst[dst_idx] = f32(shmem[shmem_idx]); -} -#enddecl(SCALAR) - -#end(DECLS) - -#define(SHADER) -diagnostic(off, chromium.subgroup_matrix_uniformity); -enable f16; -enable subgroups; -enable chromium_experimental_subgroup_matrix; - -struct MulMatParams { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - m: u32, - n: u32, - k: u32, - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) - -@group(0) @binding(3) var params: MulMatParams; - -DECLS - -// Note: These are string interpolated at build time, cannot use override constants due to limitations in -// current Dawn version type definitions/matrix load requirements for constant memory sizes. -const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u; -const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u; -// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the -// runtime subgroup size is smaller. -const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u; - -const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N; - -const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u; -const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u; -const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u; - -const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u; -const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u; - -const TILE_K = {{WEBGPU_TILE_K}}u; - -const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; -const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; - -const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE; -const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; -const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; - -const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE; - -// We reuse shmem for accumulation matrices -const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM); - -var shmem: array; - -@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) -fn main(@builtin(workgroup_id) wg_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(subgroup_id) subgroup_id: u32) { - - let thread_id = local_id.x; - let subgroup_m = subgroup_id % SUBGROUP_M; - let subgroup_n = subgroup_id / SUBGROUP_M; - - let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE; - let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE; - let wg_per_matrix = wg_m_count * wg_n_count; - - let batch_idx = wg_id.x / wg_per_matrix; - - let wg_in_batch = wg_id.x % wg_per_matrix; - let wg_m = wg_in_batch % wg_m_count; - let wg_n = wg_in_batch / wg_m_count; - - let dst2_stride = params.m * params.n; - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - - let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); - let src03_idx = dst3_idx / params.broadcast3; - let src13_idx = dst3_idx; - let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); - let src02_idx = dst2_idx / params.broadcast2; - let src12_idx = dst2_idx; - - let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; - let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - - let offset_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; - let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; - - var acc_sg_mat : array, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>; - - for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { - - // see mul_mat_decls.tmpl - init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); - init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); - - workgroupBarrier(); - - if (subgroup_id < EXPECTED_SUBGROUPS) { - - for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) { - - let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner; - var src0_sg_mats: array, SUBGROUP_MATRIX_M>; - for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { - src0_sg_mats[m] = subgroupMatrixLoad>( - &shmem, - src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K, - false, - TILE_K - ); - } - - let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner; - for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { - let src1_sg_mat = subgroupMatrixLoad>( - &shmem, - src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K, - true, - TILE_K - ); - for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { - acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]); - } - } - } - } - - workgroupBarrier(); - } - - let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; - - // Stage the subgroup matrix tiles into shared memory - // This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile). - let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE; - let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; - let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; - - if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :( - for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { - for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { - let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE; - let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE; - let out_base = local_row * WG_TILE_STRIDE + local_col; - subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE); - } - } - } - - workgroupBarrier(); - - // Cooperative write: iterate over the entire workgroup tile - let tile_rows = WG_N_SG_TILE_SIZE; - let tile_cols = WG_M_SG_TILE_SIZE; - let total_tile_elems = tile_rows * tile_cols; - let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; - let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; - - for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { - let local_row = idx % WG_TILE_STRIDE; - let local_col = idx / WG_TILE_STRIDE; - - let global_row = tile_dst_row_base + local_row; - let global_col = tile_dst_col_base + local_col; - - if (global_col < params.n && global_row < params.m) { - let dst_idx = dst_batch_offset + global_col * params.m + global_row; - store_dst(idx, dst_idx/{{VEC_SIZE}}); - } - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl new file mode 100644 index 00000000..64529e03 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl @@ -0,0 +1,188 @@ +diagnostic(off, chromium.subgroup_matrix_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; + +#include "common_decls.tmpl" +#include "mul_mat_decls.tmpl" + +#ifdef VEC +fn store_dst(shmem_idx: u32, dst_idx: u32) { + dst[dst_idx] = vec4( + f32(shmem[shmem_idx]), + f32(shmem[shmem_idx + 1]), + f32(shmem[shmem_idx + 2]), + f32(shmem[shmem_idx + 3]) + ); +} +#endif + +#ifdef SCALAR +fn store_dst(shmem_idx: u32, dst_idx: u32) { + dst[dst_idx] = f32(shmem[shmem_idx]); +} +#endif + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) + +@group(0) @binding(3) var params: MulMatParams; + +const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; +const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + +// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the +// runtime subgroup size is smaller. +const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N; +const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE; +const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; +const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + +const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE; + +// We reuse shmem for accumulation matrices +const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM); + +var shmem: array; + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32) { + + let thread_id = local_id.x; + let subgroup_m = subgroup_id % SUBGROUP_M; + let subgroup_n = subgroup_id / SUBGROUP_M; + + let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE; + let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE; + let wg_per_matrix = wg_m_count * wg_n_count; + + let batch_idx = wg_id.x / wg_per_matrix; + + let wg_in_batch = wg_id.x % wg_per_matrix; + let wg_m = wg_in_batch % wg_m_count; + let wg_n = wg_in_batch / wg_m_count; + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + + let offset_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + + var acc_sg_mat : array, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + // see mul_mat_decls.tmpl + init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); + init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); + + workgroupBarrier(); + + if (subgroup_id < EXPECTED_SUBGROUPS) { + + for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) { + + let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner; + var src0_sg_mats: array, SUBGROUP_MATRIX_M>; + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + src0_sg_mats[m] = subgroupMatrixLoad>( + &shmem, + src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K, + false, + TILE_K + ); + } + + let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner; + for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { + let src1_sg_mat = subgroupMatrixLoad>( + &shmem, + src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K, + true, + TILE_K + ); + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]); + } + } + } + } + + workgroupBarrier(); + } + + let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; + + // Stage the subgroup matrix tiles into shared memory + // This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile). + let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE; + let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + + if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :( + for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE; + let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE; + let out_base = local_row * WG_TILE_STRIDE + local_col; + subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE); + } + } + } + + workgroupBarrier(); + + // Cooperative write: iterate over the entire workgroup tile + let tile_rows = WG_N_SG_TILE_SIZE; + let tile_cols = WG_M_SG_TILE_SIZE; + let total_tile_elems = tile_rows * tile_cols; + let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + + for (var idx = thread_id * VEC_SIZE; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { + let local_row = idx % WG_TILE_STRIDE; + let local_col = idx / WG_TILE_STRIDE; + + let global_row = tile_dst_row_base + local_row; + let global_col = tile_dst_col_base + local_col; + + if (global_col < params.n && global_row < params.m) { + let dst_idx = dst_batch_offset + global_col * params.m + global_row; + store_dst(idx, dst_idx/VEC_SIZE); + } + } +} + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl deleted file mode 100644 index ffbb6403..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +++ /dev/null @@ -1,267 +0,0 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "f32_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE": "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f32_f32", - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE": "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f16_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE": "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f16", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "q4_0_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(VEC) -fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { - return f32(dot({{SRC1_TYPE}}(src0_val), src1_val)); -} - -fn store_val(group_base: u32) -> vec4 { - return vec4(partial_sums[group_base], - partial_sums[group_base + THREADS_PER_OUTPUT], - partial_sums[group_base + THREADS_PER_OUTPUT * 2], - partial_sums[group_base + THREADS_PER_OUTPUT * 3]); -} -#enddecl(VEC) - -#decl(SCALAR) -fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { - return f32(src0_val) * f32(src1_val); -} - -fn store_val(group_base: u32) -> f32 { - return partial_sums[group_base]; -} -#enddecl(SCALAR) - -#decl(MUL_ACC_FLOAT) - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) { - let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}]; - let b = shared_vector[i / {{VEC_SIZE}}]; - local_sum += inner_dot(a, b); - } - return local_sum; -} - -#enddecl(MUL_ACC_FLOAT) - -#decl(MUL_ACC_Q4_0) - -const BLOCK_SIZE = 32; -const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 1 + block_offset + j]; - let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0) * d; - local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; - } - } - } - return local_sum; -} - -#enddecl(MUL_ACC_Q4_0) - -#end(DECLS) - -#define(SHADER) -enable f16; - -DECLS - -struct MulMatParams { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - m: u32, - n: u32, - k: u32, - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // Matrix (M x K) -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed) -@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // Result vector (transposed) - -@group(0) @binding(3) var params: MulMatParams; - -override WORKGROUP_SIZE: u32; -override TILE_K: u32; -override OUTPUTS_PER_WG: u32; -override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG; - -// Shared memory for collaborative loading and reduction -var shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile -var partial_sums: array; // For reduction - -@compute @workgroup_size(WORKGROUP_SIZE) -fn main( - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) wg_id: vec3, - @builtin(num_workgroups) num_wg: vec3) { - let thread_id = local_id.x; - - // Handle batch dimensions - let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; - let batch_idx = wg_linear / output_groups; - if (batch_idx >= total_batches) { - return; - } - - // Which of the outputs does this thread belong to? - let thread_group = thread_id / THREADS_PER_OUTPUT; - let thread_in_group = thread_id % THREADS_PER_OUTPUT; - - // Each workgroup computes OUTPUTS_PER_WG consecutive outputs - let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group; - - let dst2_stride = params.m * params.n; - let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); - let src03_idx = dst3_idx / params.broadcast3; - let src13_idx = dst3_idx; - let src02_idx = dst2_idx / params.broadcast2; - let src12_idx = dst2_idx; - - let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row; - - var local_sum = 0.0; - - // Each thread processes multiple K elements and accumulates - for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) { - let tile_size = min(TILE_K, params.k - k_tile); - - // Cooperatively load vector tile into shared memory (all threads) - for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) { - shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}]; - } - - workgroupBarrier(); - - if (output_row < params.m) { - local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile); - } - - workgroupBarrier(); - } - - // Store partial sums and reduce within each partition - partial_sums[thread_id] = local_sum; - workgroupBarrier(); - let group_base = thread_group * THREADS_PER_OUTPUT; - let thread_base = group_base + thread_in_group; - var offset = THREADS_PER_OUTPUT / 2; - while (offset > 0) { - if (thread_in_group < offset) { - partial_sums[thread_base] += partial_sums[thread_base + offset]; - } - offset = offset / 2; - workgroupBarrier(); - } - - // Store back to global memory - if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) { - dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base); - } -} -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl new file mode 100644 index 00000000..f9ea95e0 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -0,0 +1,194 @@ + +enable f16; + +#include "common_decls.tmpl" + +#ifdef VEC + +#define VEC_SIZE 4 +#define DST_TYPE vec4 +#define SRC0_TYPE vec4 +#define SRC1_TYPE vec4 + +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { + return f32(dot(SRC1_TYPE(src0_val), src1_val)); +} + +fn store_val(group_base: u32) -> vec4 { + return vec4(partial_sums[group_base], + partial_sums[group_base + THREADS_PER_OUTPUT], + partial_sums[group_base + THREADS_PER_OUTPUT * 2], + partial_sums[group_base + THREADS_PER_OUTPUT * 3]); +} +#endif + +#ifdef SCALAR + +#define VEC_SIZE 1 +#define DST_TYPE f32 +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE + +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { + return f32(src0_val) * f32(src1_val); +} + +fn store_val(group_base: u32) -> f32 { + return partial_sums[group_base]; +} +#endif + +#ifdef MUL_ACC_FLOAT +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) { + let a = src0[(idx_base + k_outer + i) / VEC_SIZE]; + let b = shared_vector[i / VEC_SIZE]; + local_sum += inner_dot(a, b); + } + return local_sum; +} +#endif + +#ifdef MUL_ACC_Q4_0 + +const BLOCK_SIZE = 32; +const NQ = 16u; // number of weights per thread +const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let d = f32(src0[scale_idx]); + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 1 + block_offset + j]; + let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f32(q_byte & 0xF) - 8.0) * d; + local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; + local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; + } + } + } + return local_sum; +} +#endif + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) + +@group(0) @binding(3) var params: MulMatParams; + +const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG; + +// Shared memory for collaborative loading and reduction +var shared_vector: array; // Cache vector tile +var partial_sums: array; // For reduction + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { + let thread_id = local_id.x; + + // Handle batch dimensions + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; + let batch_idx = wg_linear / output_groups; + if (batch_idx >= total_batches) { + return; + } + + // Which of the outputs does this thread belong to? + let thread_group = thread_id / THREADS_PER_OUTPUT; + let thread_in_group = thread_id % THREADS_PER_OUTPUT; + + // Each workgroup computes OUTPUTS_PER_WG consecutive outputs + let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group; + + let dst2_stride = params.m * params.n; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row; + + var local_sum = 0.0; + + // Each thread processes multiple K elements and accumulates + for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) { + let tile_size = min(TILE_K, params.k - k_tile); + + // Cooperatively load vector tile into shared memory (all threads) + for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) { + shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE]; + } + + workgroupBarrier(); + + if (output_row < params.m) { + local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile); + } + + workgroupBarrier(); + } + + // Store partial sums and reduce within each partition + partial_sums[thread_id] = local_sum; + workgroupBarrier(); + let group_base = thread_group * THREADS_PER_OUTPUT; + let thread_base = group_base + thread_in_group; + var offset: u32 = THREADS_PER_OUTPUT / 2; + while (offset > 0) { + if (thread_in_group < offset) { + partial_sums[thread_base] += partial_sums[thread_base + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + + // Store back to global memory + if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) { + dst[dst_idx / VEC_SIZE] = store_val(group_base); + } +} + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl deleted file mode 100644 index 040e80df..00000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +++ /dev/null @@ -1,90 +0,0 @@ -#define(VARIANTS) - -[ - { - "SHADER_NAME": "scale_f32", - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "scale_f32_inplace", - "DECLS": ["INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(NOT_INPLACE) -@group(0) @binding(1) -var dst: array; - -@group(0) @binding(2) -var params: Params; - -fn store_scale(val: f32, offset: u32) { - dst[offset] = val; -} -#enddecl(NOT_INPLACE) - -#decl(INPLACE) -@group(0) @binding(1) -var params: Params; - -fn store_scale(val: f32, offset: u32) { - src[offset] = val; -} -#enddecl(INPLACE) - -#end(DECLS) - -#define(SHADER) - -struct Params { - offset_src: u32, - offset_dst: u32, - - // Strides (in elements) - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - ne: u32, - ne0: u32, - ne1: u32, - ne2: u32, - - scale: f32, - bias: f32 -}; - -@group(0) @binding(0) -var src: array; - -DECLS - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - let i1 = i / params.ne0; - let i0 = i % params.ne0; - - let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0; - let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; - - store_scale(src[i_src] * params.scale + params.bias, i_dst); -} -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl new file mode 100644 index 00000000..3b70a876 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl @@ -0,0 +1,63 @@ +#ifdef INPLACE +@group(0) @binding(1) +var params: Params; + +fn store_scale(val: f32, offset: u32) { + src[offset] = val; +} +#else +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; + +fn store_scale(val: f32, offset: u32) { + dst[offset] = val; +} +#endif + +struct Params { + offset_src: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + scale: f32, + bias: f32 +}; + +@group(0) @binding(0) +var src: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0; + let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; + + store_scale(src[i_src] * params.scale + params.bias, i_dst); +}