uint32_t mul_mat_wg_size;
};
+/** Cpy **/
+
+struct ggml_webgpu_cpy_pipeline_key {
+ ggml_type src_type;
+ ggml_type dst_type;
+
+ bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const {
+ return src_type == other.src_type && dst_type == other.dst_type;
+ }
+};
+
+struct ggml_webgpu_cpy_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.src_type);
+ ggml_webgpu_hash_combine(seed, key.dst_type);
+ return seed;
+ }
+};
+
+/** Glu **/
+
+struct ggml_webgpu_glu_pipeline_key {
+ ggml_glu_op glu_op;
+ ggml_type type;
+ bool split;
+
+ bool operator==(const ggml_webgpu_glu_pipeline_key & other) const {
+ return glu_op == other.glu_op && type == other.type && split == other.split;
+ }
+};
+
+struct ggml_webgpu_glu_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.glu_op);
+ ggml_webgpu_hash_combine(seed, key.type);
+ ggml_webgpu_hash_combine(seed, key.split);
+ return seed;
+ }
+};
+
+/** Rope **/
+
+struct ggml_webgpu_rope_pipeline_key {
+ ggml_type type;
+ bool inplace;
+ bool has_ff;
+
+ bool operator==(const ggml_webgpu_rope_pipeline_key & other) const {
+ return type == other.type && inplace == other.inplace && has_ff == other.has_ff;
+ }
+};
+
+struct ggml_webgpu_rope_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.type);
+ ggml_webgpu_hash_combine(seed, key.inplace);
+ ggml_webgpu_hash_combine(seed, key.has_ff);
+ return seed;
+ }
+};
+
+/** SoftMax **/
+
+struct ggml_webgpu_soft_max_pipeline_key {
+ ggml_type mask_type;
+ bool has_mask;
+ bool has_sink;
+ bool inplace;
+
+ bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const {
+ return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink &&
+ inplace == other.inplace;
+ }
+};
+
+struct ggml_webgpu_soft_max_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.mask_type);
+ ggml_webgpu_hash_combine(seed, key.has_mask);
+ ggml_webgpu_hash_combine(seed, key.has_sink);
+ ggml_webgpu_hash_combine(seed, key.inplace);
+ return seed;
+ }
+};
+
class ggml_webgpu_shader_lib {
wgpu::Device device;
pre_wgsl::Preprocessor preprocessor;
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
+ std::unordered_map<ggml_webgpu_cpy_pipeline_key, webgpu_pipeline, ggml_webgpu_cpy_pipeline_key_hash> cpy_pipelines;
+ std::unordered_map<ggml_webgpu_glu_pipeline_key, webgpu_pipeline, ggml_webgpu_glu_pipeline_key_hash> glu_pipelines;
+ std::unordered_map<ggml_webgpu_rope_pipeline_key, webgpu_pipeline, ggml_webgpu_rope_pipeline_key_hash>
+ rope_pipelines;
+ std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
+ soft_max_pipelines;
public:
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
return flash_attn_pipelines[key];
}
+ webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
+ ggml_webgpu_cpy_pipeline_key key = {
+ .src_type = context.src0->type,
+ .dst_type = context.dst->type,
+ };
+
+ auto it = cpy_pipelines.find(key);
+ if (it != cpy_pipelines.end()) {
+ return it->second;
+ }
+
+ std::vector<std::string> defines;
+ std::string variant = "cpy";
+
+ switch (key.src_type) {
+ case GGML_TYPE_F32:
+ defines.push_back("SRC_F32");
+ variant += "_f32";
+ break;
+ case GGML_TYPE_F16:
+ defines.push_back("SRC_F16");
+ variant += "_f16";
+ break;
+ default:
+ GGML_ABORT("Unsupported src type for cpy shader");
+ }
+
+ switch (key.dst_type) {
+ case GGML_TYPE_F32:
+ defines.push_back("DST_F32");
+ variant += "_f32";
+ break;
+ case GGML_TYPE_F16:
+ defines.push_back("DST_F16");
+ variant += "_f16";
+ break;
+ case GGML_TYPE_I32:
+ defines.push_back("DST_I32");
+ variant += "_i32";
+ break;
+ default:
+ GGML_ABORT("Unsupported dst type for cpy shader");
+ }
+
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+ auto processed = preprocessor.preprocess(wgsl_cpy, defines);
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+ decisions->wg_size = context.max_wg_size;
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+ pipeline.context = decisions;
+ cpy_pipelines[key] = pipeline;
+ return cpy_pipelines[key];
+ }
+
+ webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) {
+ ggml_webgpu_glu_pipeline_key key = {
+ .glu_op = ggml_get_glu_op(context.dst),
+ .type = context.dst->type,
+ .split = (context.src1 != nullptr),
+ };
+
+ auto it = glu_pipelines.find(key);
+ if (it != glu_pipelines.end()) {
+ return it->second;
+ }
+
+ std::vector<std::string> defines;
+ std::string variant = "glu";
+
+ switch (key.glu_op) {
+ case GGML_GLU_OP_REGLU:
+ defines.push_back("OP_REGLU");
+ variant += "_reglu";
+ break;
+ case GGML_GLU_OP_GEGLU:
+ defines.push_back("OP_GEGLU");
+ variant += "_geglu";
+ break;
+ case GGML_GLU_OP_SWIGLU:
+ defines.push_back("OP_SWIGLU");
+ variant += "_swiglu";
+ break;
+ case GGML_GLU_OP_SWIGLU_OAI:
+ defines.push_back("OP_SWIGLU_OAI");
+ variant += "_swiglu_oai";
+ break;
+ case GGML_GLU_OP_GEGLU_ERF:
+ defines.push_back("OP_GEGLU_ERF");
+ variant += "_geglu_erf";
+ break;
+ case GGML_GLU_OP_GEGLU_QUICK:
+ defines.push_back("OP_GEGLU_QUICK");
+ variant += "_geglu_quick";
+ break;
+ default:
+ GGML_ABORT("Unsupported GLU op");
+ }
+ switch (key.type) {
+ case GGML_TYPE_F32:
+ defines.push_back("TYPE_F32");
+ variant += "_f32";
+ break;
+ case GGML_TYPE_F16:
+ defines.push_back("TYPE_F16");
+ variant += "_f16";
+ break;
+ default:
+ GGML_ABORT("Unsupported type for GLU shader");
+ }
+
+ if (key.split) {
+ variant += "_split";
+ } else {
+ defines.push_back("NO_SPLIT");
+ }
+
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+ auto processed = preprocessor.preprocess(wgsl_glu, defines);
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+ decisions->wg_size = context.max_wg_size;
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+ pipeline.context = decisions;
+ glu_pipelines[key] = pipeline;
+ return glu_pipelines[key];
+ }
+
+ webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
+ ggml_webgpu_rope_pipeline_key key = {
+ .type = context.dst->type,
+ .inplace = context.inplace,
+ .has_ff = (context.src2 != nullptr),
+ };
+
+ auto it = rope_pipelines.find(key);
+ if (it != rope_pipelines.end()) {
+ return it->second;
+ }
+
+ std::vector<std::string> defines;
+ std::string variant = "rope";
+
+ switch (key.type) {
+ case GGML_TYPE_F32:
+ defines.push_back("TYPE_F32");
+ variant += "_f32";
+ break;
+ case GGML_TYPE_F16:
+ defines.push_back("TYPE_F16");
+ variant += "_f16";
+ break;
+ default:
+ GGML_ABORT("Unsupported type for ROPE shader");
+ }
+
+ if (key.inplace) {
+ defines.push_back("INPLACE");
+ variant += "_inplace";
+ }
+
+ if (key.has_ff) {
+ defines.push_back("FF_FUNC");
+ variant += "_ff";
+ }
+
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+ auto processed = preprocessor.preprocess(wgsl_rope, defines);
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+ decisions->wg_size = context.max_wg_size;
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+ pipeline.context = decisions;
+ rope_pipelines[key] = pipeline;
+ return rope_pipelines[key];
+ }
+
+ webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) {
+ ggml_webgpu_soft_max_pipeline_key key = {
+ .mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32,
+ .has_mask = (context.src1 != nullptr),
+ .has_sink = (context.src2 != nullptr),
+ .inplace = context.inplace,
+ };
+
+ auto it = soft_max_pipelines.find(key);
+ if (it != soft_max_pipelines.end()) {
+ return it->second;
+ }
+
+ std::vector<std::string> defines;
+ std::string variant = "soft_max";
+
+ if (key.has_mask) {
+ defines.push_back("HAS_MASK");
+ switch (key.mask_type) {
+ case GGML_TYPE_F32:
+ defines.push_back("MASK_F32");
+ variant += "_mask_f32";
+ break;
+ case GGML_TYPE_F16:
+ defines.push_back("MASK_F16");
+ variant += "_mask_f16";
+ break;
+ default:
+ GGML_ABORT("Unsupported type for SOFT_MAX shader");
+ }
+ }
+
+ if (key.has_sink) {
+ defines.push_back("HAS_SINK");
+ variant += "_sink";
+ }
+
+ if (key.inplace) {
+ defines.push_back("INPLACE");
+ variant += "_inplace";
+ }
+
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+ auto processed = preprocessor.preprocess(wgsl_soft_max, defines);
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+ decisions->wg_size = context.max_wg_size;
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+ pipeline.context = decisions;
+ soft_max_pipelines[key] = pipeline;
+ return soft_max_pipelines[key];
+ }
+
private:
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
std::string shader_code,
wgpu::Buffer set_rows_dev_error_buf;
wgpu::Buffer set_rows_host_error_buf;
- std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
-
- std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
- std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
-
- std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
-
size_t memset_bytes_per_thread;
};
}
static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
+ .src0 = src,
+ .dst = dst,
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+ };
+
+ webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx);
+
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
uint32_t ne = (uint32_t) ggml_nelements(dst);
std::vector<uint32_t> params = {
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
- uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
- return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type],
- params, entries, wg_x);
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * dst) {
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
+ .src0 = src0,
+ .src1 = src1,
+ .src2 = src2,
+ .dst = dst,
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+ .inplace = ggml_webgpu_tensor_equal(src0, dst),
+ };
+
+ webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx);
+
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
const int has_freq_factor = (src2 != nullptr);
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
- webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
- uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
+ .src0 = src0,
+ .src1 = src1,
+ .dst = dst,
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+ };
+
+ webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx);
+
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
const int split = (src1 != nullptr);
std::vector<uint32_t> params = {
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
- webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
- uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * dst) {
- const int inplace = ggml_webgpu_tensor_equal(src0, dst);
- const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here
- const int has_sink = (src2 != nullptr);
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
+ .src0 = src0,
+ .src1 = src1,
+ .src2 = src2,
+ .dst = dst,
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+ .inplace = ggml_webgpu_tensor_equal(src0, dst),
+ };
+
+ webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx);
+
+ const int inplace = ggml_webgpu_tensor_equal(src0, dst);
+ const int has_mask = (src1 != nullptr);
+ const int has_sink = (src2 != nullptr);
float max_bias;
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
- mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
+ has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
- mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
- mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
- mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
+ has_mask ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
+ has_mask ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
+ has_mask ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
(uint32_t) src0->ne[0],
(uint32_t) src0->ne[1],
(uint32_t) src0->ne[2],
- mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
- mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
+ has_mask ? (uint32_t) src1->ne[2] : 0,
+ has_mask ? (uint32_t) src1->ne[3] : 0,
*(uint32_t *) dst->op_params, // scale
*(uint32_t *) &max_bias,
*(uint32_t *) &n_head_log2,
.size = ggml_webgpu_tensor_binding_size(ctx, src0) }
};
uint32_t binding_num = 1;
- if (mask_type < 2) {
+ if (has_mask) {
entries.push_back({ .binding = binding_num,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
- return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool,
- ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
- ggml_nrows(dst));
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(dst));
}
static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
}
-static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
- webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
- webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
- webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
- webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
- webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
-}
-
-static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
- webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants);
- webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
- webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
- webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
-
- webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants);
- webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
- webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
- webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
-}
-
-static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
- // REGLU
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
-
- // GEGLU
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
-
- // SWIGLU
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
-
- // SWIGLU_OAI
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
-
- // GEGLU_ERF
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
-
- // GEGLU_QUICK
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
- webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
-}
-
-static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
-
- // f32 (no mask)
- webgpu_ctx->soft_max_pipelines[2][0][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
- webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
- webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
- webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
-
- // f32 mask (mask_type = 0)
- webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
- webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
- webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
- webgpu_ctx->soft_max_pipelines[0][1][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace,
- "soft_max_f32_mask_f32_sink_inplace", constants);
-
- // f16 mask (mask_type = 1)
- webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
- webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
- webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
- webgpu_ctx->soft_max_pipelines[1][1][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace,
- "soft_max_f32_mask_f16_sink_inplace", constants);
-}
-
static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
wgpu::RequestAdapterOptions options = {};
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
- ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
- ggml_webgpu_init_rope_pipeline(webgpu_ctx);
- ggml_webgpu_init_glu_pipeline(webgpu_ctx);
- ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,
+++ /dev/null
-#define(VARIANTS)
-
-[
- {
- "REPLS": {
- "SRC_TYPE": "f32",
- "DST_TYPE": "f32"
- }
- },
- {
- "REPLS": {
- "SRC_TYPE": "f32",
- "DST_TYPE": "i32"
- }
- },
- {
- "REPLS": {
- "SRC_TYPE": "f32",
- "DST_TYPE": "f16"
- }
- },
- {
- "REPLS": {
- "SRC_TYPE": "f16",
- "DST_TYPE": "f16"
- }
- },
- {
- "REPLS": {
- "SRC_TYPE": "f16",
- "DST_TYPE": "f32"
- }
- }
-]
-
-#end(VARIANTS)
-
-#define(SHADER)
-enable f16;
-
-@group(0) @binding(0)
-var<storage, read_write> src: array<{{SRC_TYPE}}>;
-
-@group(0) @binding(1)
-var<storage, read_write> dst: array<{{DST_TYPE}}>;
-
-struct Params {
- ne: u32, // total number of elements
- offset_src: u32, // in elements
- offset_dst: u32, // in elements
-
- // Strides (in elements) — may be permuted
- stride_src0: u32,
- stride_src1: u32,
- stride_src2: u32,
- stride_src3: u32,
-
- stride_dst0: u32,
- stride_dst1: u32,
- stride_dst2: u32,
- stride_dst3: u32,
-
- // Logical shapes
- src_ne0: u32,
- src_ne1: u32,
- src_ne2: u32,
-
- dst_ne0: u32,
- dst_ne1: u32,
- dst_ne2: u32
-};
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
- if (gid.x >= params.ne) {
- return;
- }
-
- var i = gid.x;
- let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
- i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
- let i2 = i / (params.src_ne1 * params.src_ne0);
- i = i % (params.src_ne1 * params.src_ne0);
- let i1 = i / params.src_ne0;
- let i0 = i % params.src_ne0;
-
- var j = gid.x;
- let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
- j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
- let j2 = j / (params.dst_ne1 * params.dst_ne0);
- j = j % (params.dst_ne1 * params.dst_ne0);
- let j1 = j / params.dst_ne0;
- let j0 = j % params.dst_ne0;
-
- let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
- i2 * params.stride_src2 + i3 * params.stride_src3;
-
- let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
- j2 * params.stride_dst2 + j3 * params.stride_dst3;
-
- dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));
-}
-#end(SHADER)
--- /dev/null
+enable f16;
+
+#ifdef SRC_F32
+#define SRC_TYPE f32
+#elif defined(SRC_F16)
+#define SRC_TYPE f16
+#endif
+
+#ifdef DST_F32
+#define DST_TYPE f32
+#elif defined(DST_F16)
+#define DST_TYPE f16
+#elif defined(DST_I32)
+#define DST_TYPE i32
+#endif
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<SRC_TYPE>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<DST_TYPE>;
+
+struct Params{
+ ne: u32,
+ offset_src: u32,
+ offset_dst: u32,
+
+ stride_src0: u32,
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+
+ stride_dst0: u32,
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ src_ne0: u32,
+ src_ne1: u32,
+ src_ne2: u32,
+
+ dst_ne0: u32,
+ dst_ne1: u32,
+ dst_ne2: u32
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ if (gid.x >= params.ne) {
+ return;
+ }
+
+ var i = gid.x;
+ let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
+ i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
+ let i2 = i / (params.src_ne1 * params.src_ne0);
+ i = i % (params.src_ne1 * params.src_ne0);
+ let i1 = i / params.src_ne0;
+ let i0 = i % params.src_ne0;
+
+ var j = gid.x;
+ let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
+ j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
+ let j2 = j / (params.dst_ne1 * params.dst_ne0);
+ j = j % (params.dst_ne1 * params.dst_ne0);
+ let j1 = j / params.dst_ne0;
+ let j0 = j % params.dst_ne0;
+
+ let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
+ i2 * params.stride_src2 + i3 * params.stride_src3;
+
+ let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
+ j2 * params.stride_dst2 + j3 * params.stride_dst3;
+
+ dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx]));
+}
+
import os
import re
-import ast
import argparse
-def extract_block(text, name):
- pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)'
- match = re.search(pattern, text, re.DOTALL)
- if not match:
- raise ValueError(f"Missing block: {name}")
- return match.group(1).strip()
-
-
-def parse_decls(decls_text):
- decls = {}
- for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL):
- decls[name.strip()] = code.strip()
- return decls
-
-
-def replace_repl_placeholders(variant, template_map):
- for repl, code in variant["REPLS"].items():
- for key, val in template_map.items():
- # Match "key" and avoid matching subsequences using by using \b
- code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code)
- variant["REPLS"][repl] = code
- return variant
-
-
-def replace_placeholders(shader_text, replacements):
- for key, val in replacements.items():
- # Match {{KEY}} literally, where KEY is escaped
- pattern = r'{{\s*' + re.escape(key) + r'\s*}}'
- shader_text = re.sub(pattern, str(val), shader_text)
- return shader_text
-
-
def expand_includes(shader, input_dir):
"""
Replace #include "file" lines in the text with the contents of that file.
outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n')
-def generate_variants(fname, input_dir, output_dir, outfile):
- shader_path = os.path.join(input_dir, fname)
- shader_base_name = fname.split(".")[0]
-
- with open(shader_path, "r", encoding="utf-8") as f:
- text = f.read()
-
- try:
- variants = ast.literal_eval(extract_block(text, "VARIANTS"))
- except ValueError:
- write_shader(shader_base_name, text, output_dir, outfile, input_dir)
- else:
- try:
- decls_map = parse_decls(extract_block(text, "DECLS"))
- except ValueError:
- decls_map = {}
- try:
- templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES"))
- except ValueError:
- templates_map = {}
-
- for fname in sorted(os.listdir(input_dir)):
- if fname.endswith(".tmpl"):
- tmpl_path = os.path.join(input_dir, fname)
- with open(tmpl_path, "r", encoding="utf-8") as f_tmpl:
- decls = f_tmpl.read()
- decls_map.update(parse_decls(decls))
-
- shader_template = extract_block(text, "SHADER")
- for variant in variants:
- if "DECLS" in variant:
- decls = variant["DECLS"]
- else:
- decls = []
- decls_code = ""
- for key in decls:
- if key not in decls_map:
- raise ValueError(f"DECLS key '{key}' not found.")
- decls_code += decls_map[key] + "\n\n"
- final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
- if "REPLS" in variant:
- variant = replace_repl_placeholders(variant, templates_map)
- final_shader = replace_placeholders(final_shader, variant["REPLS"])
- # second run to expand placeholders in repl_template
- final_shader = replace_placeholders(final_shader, variant["REPLS"])
- final_shader = expand_includes(final_shader, input_dir)
-
- if "SHADER_NAME" in variant:
- output_name = variant["SHADER_NAME"]
- elif "SHADER_SUFFIX" in variant:
- output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
- elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
- output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
- elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]:
- output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]])
- elif "REPLS" in variant and "TYPE" in variant["REPLS"]:
- output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
- else:
- output_name = shader_base_name
- write_shader(output_name, final_shader, output_dir, outfile, input_dir)
-
-
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", required=True)
parser.add_argument("--output_file", required=True)
- parser.add_argument("--output_dir")
args = parser.parse_args()
- if args.output_dir:
- os.makedirs(args.output_dir, exist_ok=True)
-
with open(args.output_file, "w", encoding="utf-8") as out:
out.write("// Auto-generated shader embedding\n")
out.write("#include <string>\n\n")
for fname in sorted(os.listdir(args.input_dir)):
if fname.endswith(".wgsl"):
- generate_variants(fname, args.input_dir, args.output_dir, out)
+ shader_path = os.path.join(args.input_dir, fname)
+ shader_name = fname.replace(".wgsl", "")
+
+ with open(shader_path, "r", encoding="utf-8") as f:
+ shader_code = f.read()
+
+ write_shader(shader_name, shader_code, None, out, args.input_dir)
if __name__ == "__main__":
+++ /dev/null
-#define(VARIANTS)
-
-[
- {
- "SHADER_NAME": "reglu_f32",
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["NO_SPLIT", "REGLU"]
- },
- {
- "SHADER_NAME": "reglu_f32_split",
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["SPLIT", "REGLU"]
- },
- {
- "SHADER_NAME": "reglu_f16",
- "REPLS": {
- "TYPE" : "f16",
- },
- "DECLS": ["NO_SPLIT", "REGLU"]
- },
- {
- "SHADER_NAME": "reglu_f16_split",
- "REPLS": {
- "TYPE" : "f16",
- },
- "DECLS": ["SPLIT", "REGLU"]
- },
- {
- "SHADER_NAME": "geglu_f32",
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["NO_SPLIT", "GEGLU"]
- },
- {
- "SHADER_NAME": "geglu_f32_split",
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["SPLIT", "GEGLU"]
- },
- {
- "SHADER_NAME": "geglu_f16",
- "REPLS": {
- "TYPE" : "f16",
- },
- "DECLS": ["NO_SPLIT", "GEGLU"]
- },
- {
- "SHADER_NAME": "geglu_f16_split",
- "REPLS": {
- "TYPE" : "f16",
- },
- "DECLS": ["SPLIT", "GEGLU"]
- },
- {
- "SHADER_NAME": "swiglu_f32",
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["NO_SPLIT", "SWIGLU"]
- },
- {
- "SHADER_NAME": "swiglu_f32_split",
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["SPLIT", "SWIGLU"]
- },
- {
- "SHADER_NAME": "swiglu_f16",
- "REPLS": {
- "TYPE" : "f16",
- },
- "DECLS": ["NO_SPLIT", "SWIGLU"]
- },
- {
- "SHADER_NAME": "swiglu_f16_split",
- "REPLS": {
- "TYPE" : "f16",
- },
- "DECLS": ["SPLIT", "SWIGLU"]
- },
- {
- "SHADER_NAME": "swiglu_oai_f32",
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["NO_SPLIT", "SWIGLU_OAI"]
- },
- {
- "SHADER_NAME": "swiglu_oai_f32_split",
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["SPLIT", "SWIGLU_OAI"]
- },
- {
- "SHADER_NAME": "geglu_erf_f32",
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["NO_SPLIT", "GEGLU_ERF"]
- },
- {
- "SHADER_NAME": "geglu_erf_f32_split",
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["SPLIT", "GEGLU_ERF"]
- },
- {
- "SHADER_NAME": "geglu_erf_f16",
- "REPLS": {
- "TYPE" : "f16",
- },
- "DECLS": ["NO_SPLIT", "GEGLU_ERF"]
- },
- {
- "SHADER_NAME": "geglu_erf_f16_split",
- "REPLS": {
- "TYPE" : "f16",
- },
- "DECLS": ["SPLIT", "GEGLU_ERF"]
- },
- {
- "SHADER_NAME": "geglu_quick_f32",
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
- },
- {
- "SHADER_NAME": "geglu_quick_f32_split",
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["SPLIT", "GEGLU_QUICK"]
- },
- {
- "SHADER_NAME": "geglu_quick_f16",
- "REPLS": {
- "TYPE" : "f16",
- },
- "DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
- },
- {
- "SHADER_NAME": "geglu_quick_f16_split",
- "REPLS": {
- "TYPE" : "f16",
- },
- "DECLS": ["SPLIT", "GEGLU_QUICK"]
- },
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(REGLU)
-fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
- return max(a, 0) * b;
-}
-#enddecl(REGLU)
-
-#decl(GEGLU)
-const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876;
-const GELU_COEF_A: {{TYPE}} = 0.044715;
-
-fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
- let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
- return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b;
-}
-#enddecl(GEGLU)
-
-#decl(SWIGLU)
-fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
- return a / (1.0 + exp(-a)) * b;
-}
-#enddecl(SWIGLU)
-
-#decl(SWIGLU_OAI)
-fn op(a: f32, b: f32) -> f32 {
- let xi = min(a, params.limit);
- let gi = max(min(b, params.limit), -params.limit);
- var out_glu = xi / (1.0 + exp(-xi * params.alpha));
- out_glu = out_glu * (1.0 + gi);
- return out_glu;
-}
-#enddecl(SWIGLU_OAI)
-
-#decl(GEGLU_ERF)
-const p_erf: {{TYPE}} = 0.3275911;
-const a1_erf: {{TYPE}} = 0.254829592;
-const a2_erf: {{TYPE}} = -0.284496736;
-const a3_erf: {{TYPE}} = 1.421413741;
-const a4_erf: {{TYPE}} = -1.453152027;
-const a5_erf: {{TYPE}} = 1.061405429;
-const SQRT_2_INV: {{TYPE}} = 0.7071067811865476;
-
-fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
- let a_div_sqr2 = a * SQRT_2_INV;
- let sign_x = sign(a_div_sqr2);
- let x = abs(a_div_sqr2);
- let t = 1.0 / (1.0 + p_erf * x);
- let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
- let erf_approx = sign_x * y;
- return 0.5 * a * (1.0 + erf_approx) * b;
-}
-#enddecl(GEGLU_ERF)
-
-#decl(GEGLU_QUICK)
-const GELU_QUICK_COEF: {{TYPE}} = -1.702;
-
-fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
- return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
-}
-#enddecl(GEGLU_QUICK)
-
-#decl(NO_SPLIT)
-@group(0) @binding(1)
-var<storage, read_write> dst: array<{{TYPE}}>;
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-fn a_value(base: u32) -> {{TYPE}} {
- let offset: u32 = select(0, params.ne0, params.swapped != 0);
- return src0[base + offset];
-}
-
-fn b_value(base: u32) -> {{TYPE}} {
- let offset: u32 = select(params.ne0, 0, params.swapped != 0);
- return src0[base + offset];
-}
-#enddecl(NO_SPLIT)
-
-#decl(SPLIT)
-@group(0) @binding(1)
-var<storage, read_write> src1: array<{{TYPE}}>;
-
-@group(0) @binding(2)
-var<storage, read_write> dst: array<{{TYPE}}>;
-
-@group(0) @binding(3)
-var<uniform> params: Params;
-
-fn a_value(base: u32) -> {{TYPE}} {
- return src0[base];
-}
-
-fn b_value(base: u32) -> {{TYPE}} {
- return src1[base];
-}
-#enddecl(SPLIT)
-
-#end(DECLS)
-
-#define(SHADER)
-
-enable f16;
-
-struct Params {
- offset_src0: u32,
- offset_src1: u32,
- offset_dst: u32,
-
- // Strides (in elements)
- stride_src01: u32,
- stride_src02: u32,
- stride_src03: u32,
-
- stride_src11: u32,
- stride_src12: u32,
- stride_src13: u32,
-
- stride_dst1: u32,
- stride_dst2: u32,
- stride_dst3: u32,
-
- // shape of dst
- ne: u32,
- ne0: u32,
- ne1: u32,
- ne2: u32,
-
- swapped: u32,
- alpha: f32,
- limit: f32,
-}
-
-@group(0) @binding(0)
-var<storage, read_write> src0: array<{{TYPE}}>;
-
-DECLS
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
- if (gid.x >= params.ne) {
- return;
- }
-
- var i = gid.x;
- let i3 = i / (params.ne2 * params.ne1 * params.ne0);
- i = i % (params.ne2 * params.ne1 * params.ne0);
- let i2 = i / (params.ne1 * params.ne0);
- i = i % (params.ne1 * params.ne0);
- let i1 = i / params.ne0;
- let i0 = i % params.ne0;
-
- let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
- let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
- let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
-
- dst[i_dst] = op(a_value(i_a), b_value(i_b));
-}
-
-#end(SHADER)
--- /dev/null
+enable f16;
+
+#ifdef TYPE_F32
+#define DataType f32
+#endif
+#ifdef TYPE_F16
+#define DataType f16
+#endif
+
+#ifdef OP_REGLU
+fn op(a: DataType, b: DataType) -> DataType {
+ return max(a, 0) * b;
+}
+#endif
+
+#ifdef OP_GEGLU
+const SQRT_2_OVER_PI: DataType = 0.79788456080286535587989211986876;
+const GELU_COEF_A: DataType = 0.044715;
+
+fn op(a: DataType, b: DataType) -> DataType {
+ let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
+ return 0.5 * a * (2.0 - 2.0/ (exp(2* val) + 1)) * b;
+}
+#endif
+
+#ifdef OP_SWIGLU
+fn op(a: DataType, b: DataType) -> DataType {
+ return a / (1.0 + exp(-a)) * b;
+}
+#endif
+#ifdef OP_SWIGLU_OAI
+fn op(a: f32, b: f32) -> f32 {
+ let xi = min(a, params.limit);
+ let gi = max(min(b, params.limit), -params.limit);
+ var out_glu = xi / (1.0 + exp(-xi * params.alpha));
+ out_glu = out_glu * (1.0 + gi);
+ return out_glu;
+}
+#endif
+#ifdef OP_GEGLU_ERF
+const p_erf: DataType = 0.3275911;
+const a1_erf: DataType = 0.254829592;
+const a2_erf: DataType = -0.284496736;
+const a3_erf: DataType = 1.421413741;
+const a4_erf: DataType = -1.453152027;
+const a5_erf: DataType = 1.061405429;
+const SQRT_2_INV: DataType = 0.7071067811865476;
+
+fn op(a: DataType, b: DataType) -> DataType {
+ let a_div_sqr2 = a * SQRT_2_INV;
+ let sign_x = sign(a_div_sqr2);
+ let x = abs(a_div_sqr2);
+ let t = 1.0 / (1.0 + p_erf * x);
+ let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
+ let erf_approx = sign_x * y;
+ return 0.5 * a * (1.0 + erf_approx) * b;
+}
+#endif
+#ifdef OP_GEGLU_QUICK
+const GELU_QUICK_COEF: DataType = -1.702;
+
+fn op(a: DataType, b: DataType) -> DataType {
+ return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
+}
+#endif
+
+struct Params {
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_dst: u32,
+
+ // Strides (in elements)
+ stride_src01: u32,
+ stride_src02: u32,
+ stride_src03: u32,
+
+ stride_src11: u32,
+ stride_src12: u32,
+ stride_src13: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // shape of dst
+ ne: u32,
+ ne0: u32,
+ ne1: u32,
+ ne2: u32,
+
+ swapped: u32,
+ alpha: f32,
+ limit: f32,
+}
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<DataType>;
+
+#ifdef NO_SPLIT
+@group(0) @binding(1)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+fn a_value(base: u32) -> DataType {
+ let offset: u32 = select(0, params.ne0, params.swapped != 0);
+ return src0[base + offset];
+}
+
+fn b_value(base: u32) -> DataType {
+ let offset: u32 = select(params.ne0, 0, params.swapped != 0);
+ return src0[base + offset];
+}
+
+#else
+@group(0) @binding(1)
+var<storage, read_write> src1: array<DataType>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+fn a_value(base: u32) -> DataType {
+ return src0[base];
+}
+
+fn b_value(base: u32) -> DataType {
+ return src1[base];
+}
+
+#endif
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ if (gid.x >= params.ne) {
+ return;
+ }
+
+ var i = gid.x;
+ let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+ i = i % (params.ne2 * params.ne1 * params.ne0);
+ let i2 = i / (params.ne1 * params.ne0);
+ i = i % (params.ne1 * params.ne0);
+ let i1 = i / params.ne0;
+ let i0 = i % params.ne0;
+
+ let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
+ let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
+ let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
+
+ dst[i_dst] = op(a_value(i_a), b_value(i_b));
+}
+++ /dev/null
-#define(VARIANTS)
-
-[
- {
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
- },
- {
- "SHADER_SUFFIX": "f32_inplace",
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
- },
- {
- "REPLS": {
- "TYPE" : "f16",
- },
- "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
- },
- {
- "SHADER_SUFFIX": "f16_inplace",
- "REPLS": {
- "TYPE" : "f16",
- },
- "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
- },
- {
- "SHADER_SUFFIX": "f32_ff",
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
- },
- {
- "SHADER_SUFFIX": "f32_ff_inplace",
- "REPLS": {
- "TYPE" : "f32",
- },
- "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
- },
- {
- "SHADER_SUFFIX": "f16_ff",
- "REPLS": {
- "TYPE" : "f16",
- },
- "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
- },
- {
- "SHADER_SUFFIX": "f16_ff_inplace",
- "REPLS": {
- "TYPE" : "f16",
- },
- "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
- }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(ROTATE)
-fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
- dst[i_dst0] = {{TYPE}}(out0);
- dst[i_dst1] = {{TYPE}}(out1);
-}
-#enddecl(ROTATE)
-
-#decl(ROTATE_INPLACE)
-fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
- src0[i_dst0] = {{TYPE}}(out0);
- src0[i_dst1] = {{TYPE}}(out1);
-}
-#enddecl(ROTATE_INPLACE)
-
-#decl(NO_FF_FUNC)
-fn freq_factor(i: u32) -> f32 {
- return 1.0f;
-}
-#enddecl(NO_FF_FUNC)
-
-#decl(FF_FUNC)
-fn freq_factor(i: u32) -> f32 {
- return src2[params.offset_src2 + i/2];
-}
-#enddecl(FF_FUNC)
-
-#decl(NO_FF_BINDINGS)
-
-@group(0) @binding(2)
-var<storage, read_write> dst: array<{{TYPE}}>;
-
-@group(0) @binding(3)
-var<uniform> params: Params;
-
-#enddecl(NO_FF_BINDINGS)
-
-#decl(NO_FF_BINDINGS_INPLACE)
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-#enddecl(NO_FF_BINDINGS_INPLACE)
-
-#decl(FF_BINDINGS)
-
-@group(0) @binding(2)
-var<storage, read_write> src2: array<f32>;
-
-@group(0) @binding(3)
-var<storage, read_write> dst: array<{{TYPE}}>;
-
-@group(0) @binding(4)
-var<uniform> params: Params;
-
-#enddecl(FF_BINDINGS)
-
-#decl(FF_BINDINGS_INPLACE)
-
-@group(0) @binding(2)
-var<storage, read_write> src2: array<f32>;
-
-@group(0) @binding(3)
-var<uniform> params: Params;
-
-#enddecl(FF_BINDINGS_INPLACE)
-
-#end(DECLS)
-
-#define(SHADER)
-
-enable f16;
-
-struct Params {
- offset_src0: u32,
- offset_src1: u32,
- offset_src2: u32,
- offset_dst: u32,
-
- // Strides (in elements)
- stride_src01: u32,
- stride_src02: u32,
- stride_src03: u32,
-
- stride_dst1: u32,
- stride_dst2: u32,
- stride_dst3: u32,
-
- n_threads: u32,
- ne0: u32,
- ne1: u32,
- ne2: u32,
-
- n_dims: u32,
- mode: u32,
- theta_scale: f32,
- attn_factor: f32,
- freq_scale: f32,
- ext_factor: f32,
- corr_dim0: f32,
- corr_dim1: f32,
- sections0: u32,
- sections1: u32,
- sections2: u32,
- sections3: u32
-};
-
-@group(0) @binding(0)
-var<storage, read_write> src0: array<{{TYPE}}>;
-
-@group(0) @binding(1)
-var<storage, read_write> src1: array<i32>;
-
-DECLS
-
-fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
- let y = (f32(i / 2) - low) / max(0.001f, high - low);
- return 1.0f - min(1.0f, max(0.0f, y));
-}
-
-// returns vector of (cos_theta, sin_theta)
-// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row
-fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
- var mscale = params.attn_factor;
- var theta = params.freq_scale * theta_extrap;
- if (params.ext_factor != 0.0f) {
- let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
- theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
- mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale);
- }
- return vec2<f32>(cos(theta) * mscale, sin(theta) * mscale);
-}
-
-fn pair_base(i0: u32, div_2: bool) -> u32 {
- if (div_2) {
- return i0 / 2;
- } else {
- return i0;
- }
-}
-
-fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
- if (is_vision) {
- return params.n_dims;
- } else if (is_neox || is_mrope) {
- return params.n_dims / 2;
- } else {
- return 1;
- }
-}
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
- // two elements per thread
- if (gid.x >= params.n_threads) {
- return;
- }
-
- let is_neox = bool(params.mode & 2);
- let is_mrope = bool(params.mode & 8);
- let is_imrope = params.mode == 40;
- let is_vision = params.mode == 24;
-
- var i = gid.x * 2; // start index for this thread
- let i3 = i / (params.ne2 * params.ne1 * params.ne0);
- i = i % (params.ne2 * params.ne1 * params.ne0);
- let i2 = i / (params.ne1 * params.ne0);
- i = i % (params.ne1 * params.ne0);
- let i1 = i / params.ne0;
- let i0 = i % params.ne0;
-
- let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
- let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
-
- if (i0 >= params.n_dims && !is_vision) {
- let i_src = i_src_row + i0;
- let i_dst = i_dst_row + i0;
- rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
- return;
- }
-
- var theta_base_mult: u32 = 0;
- var theta_scale_pwr: u32 = i0 / 2;
- if (is_mrope) {
- let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3;
- let sec_w = params.sections1 + params.sections0;
- let sec_e = params.sections2 + sec_w;
- let sector = (i0 / 2) % sect_dims;
- if (is_imrope) {
- if (sector % 3 == 1 && sector < 3 * params.sections1) {
- theta_base_mult = 1;
- } else if (sector % 3 == 2 && sector < 3 * params.sections2) {
- theta_base_mult = 2;
- } else if (sector % 3 == 0 && sector < 3 * params.sections0) {
- theta_base_mult = 0;
- } else {
- theta_base_mult = 3;
- }
- } else {
- if (sector >= params.sections0 && sector < sec_w) {
- theta_base_mult = 1;
- if (is_vision) {
- theta_scale_pwr = sector - params.sections0;
- }
- } else if (sector >= sec_w && sector < sec_e) {
- theta_base_mult = 2;
- if (is_vision) {
- theta_scale_pwr = sector - sec_w;
- }
- } else if (sector >= sec_e) {
- if (is_vision) {
- theta_scale_pwr = sector - sec_e;
- theta_scale_pwr = (i0 / 2) % sec_e;
- }
- theta_base_mult = 3;
- } else if (is_vision) {
- theta_scale_pwr = sector;
- }
- }
- }
- let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
- let thetas = rope_yarn(theta_base/freq_factor(i0), i0);
-
- let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision);
- let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision);
-
- let x0 = f32(src0[i_src]);
- let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
- rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
-}
-
-#end(SHADER)
--- /dev/null
+enable f16;
+
+#ifdef TYPE_F32
+#define DataType f32
+#endif
+#ifdef TYPE_F16
+#define DataType f16
+#endif
+
+struct Params {
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_src2: u32,
+ offset_dst: u32,
+
+ // Strides (in elements)
+ stride_src01: u32,
+ stride_src02: u32,
+ stride_src03: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ n_threads: u32,
+ ne0: u32,
+ ne1: u32,
+ ne2: u32,
+
+ n_dims: u32,
+ mode: u32,
+ theta_scale: f32,
+ attn_factor: f32,
+ freq_scale: f32,
+ ext_factor: f32,
+ corr_dim0: f32,
+ corr_dim1: f32,
+ sections0: u32,
+ sections1: u32,
+ sections2: u32,
+ sections3: u32
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<DataType>;
+@group(0) @binding(1)
+var<storage, read_write> src1: array<i32>;
+
+#ifdef INPLACE
+
+#ifdef FF_FUNC
+
+@group(0) @binding(2)
+var<storage, read_write> src2: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+#else
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#endif
+
+#else
+
+#ifdef FF_FUNC
+@group(0) @binding(2)
+var<storage, read_write> src2: array<f32>;
+
+@group(0) @binding(3)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(4)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(2)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#endif
+#endif
+
+#ifdef FF_FUNC
+fn freq_factor(i: u32) -> f32 {
+ return src2[params.offset_src2 + i/2];
+}
+
+#else
+fn freq_factor(i: u32) -> f32 {
+ return 1.0f;
+}
+#endif
+#ifdef INPLACE
+fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
+ src0[i_dst0] = DataType(out0);
+ src0[i_dst1] = DataType(out1);
+}
+#else
+fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
+ dst[i_dst0] = DataType(out0);
+ dst[i_dst1] = DataType(out1);
+}
+#endif
+
+fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
+ let y = (f32(i / 2) - low) / max(0.001f, high - low);
+ return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// returns vector of (cos_theta, sin_theta)
+// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row
+fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
+ var mscale = params.attn_factor;
+ var theta = params.freq_scale * theta_extrap;
+ if (params.ext_factor != 0.0f) {
+ let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
+ theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
+ mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale);
+ }
+ return vec2<f32>(cos(theta) * mscale, sin(theta) * mscale);
+}
+
+fn pair_base(i0: u32, div_2: bool) -> u32 {
+ if (div_2) {
+ return i0 / 2;
+ } else {
+ return i0;
+ }
+}
+
+fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
+ if (is_vision) {
+ return params.n_dims;
+ } else if (is_neox || is_mrope) {
+ return params.n_dims / 2;
+ } else {
+ return 1;
+ }
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ // two elements per n_threads
+ if (gid.x >= params.n_threads) {
+ return;
+ }
+
+ let is_neox = bool(params.mode & 2);
+ let is_mrope = bool(params.mode & 8);
+ let is_imrope = params.mode == 40;
+ let is_vision = params.mode == 24;
+
+ var i = gid.x * 2; // start index for this thread
+ let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+ i = i % (params.ne2 * params.ne1 * params.ne0);
+ let i2 = i / (params.ne1 * params.ne0);
+ i = i % (params.ne1 * params.ne0);
+ let i1 = i / params.ne0;
+ let i0 = i % params.ne0;
+
+ let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
+ let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+
+ if (i0 >= params.n_dims && !is_vision) {
+ let i_src = i_src_row + i0;
+ let i_dst = i_dst_row + i0;
+ rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
+ return;
+ }
+
+ var theta_base_mult: u32 = 0;
+ var theta_scale_pwr: u32 = i0 / 2;
+ if (is_mrope) {
+ let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3;
+ let sec_w = params.sections1 + params.sections0;
+ let sec_e = params.sections2 + sec_w;
+ let sector = (i0 / 2) % sect_dims;
+ if (is_imrope) {
+ if (sector % 3 == 1 && sector < 3 * params.sections1) {
+ theta_base_mult = 1;
+ } else if (sector % 3 == 2 && sector < 3 * params.sections2) {
+ theta_base_mult = 2;
+ } else if (sector % 3 == 0 && sector < 3 * params.sections0) {
+ theta_base_mult = 0;
+ } else {
+ theta_base_mult = 3;
+ }
+ } else {
+ if (sector >= params.sections0 && sector < sec_w) {
+ theta_base_mult = 1;
+ if (is_vision) {
+ theta_scale_pwr = sector - params.sections0;
+ }
+ } else if (sector >= sec_w && sector < sec_e) {
+ theta_base_mult = 2;
+ if (is_vision) {
+ theta_scale_pwr = sector - sec_w;
+ }
+ } else if (sector >= sec_e) {
+ if (is_vision) {
+ theta_scale_pwr = sector - sec_e;
+ theta_scale_pwr = (i0 / 2) % sec_e;
+ }
+ theta_base_mult = 3;
+ } else if (is_vision) {
+ theta_scale_pwr = sector;
+ }
+ }
+ }
+ let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
+ let thetas = rope_yarn(theta_base/freq_factor(i0), i0);
+
+ let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision);
+ let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision);
+
+ let x0 = f32(src0[i_src]);
+ let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
+ rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
+
+}
+++ /dev/null
-#define(VARIANTS)
-[
- {
- "SHADER_NAME": "soft_max_f32",
- "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"]
- },
- {
- "SHADER_NAME": "soft_max_f32_inplace",
- "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"]
- },
- {
- "SHADER_NAME": "soft_max_f32_sink",
- "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"]
- },
- {
- "SHADER_NAME": "soft_max_f32_sink_inplace",
- "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"]
- },
- {
- "SHADER_NAME": "soft_max_f32_mask_f32",
- "REPLS": {
- "MASK_TYPE" : "f32",
- },
- "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
- },
- {
- "SHADER_NAME": "soft_max_f32_mask_f32_inplace",
- "REPLS": {
- "MASK_TYPE" : "f32",
- },
- "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
- },
- {
- "SHADER_NAME": "soft_max_f32_mask_f16",
- "REPLS": {
- "MASK_TYPE" : "f16",
- },
- "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
- },
- {
- "SHADER_NAME": "soft_max_f32_mask_f16_inplace",
- "REPLS": {
- "MASK_TYPE" : "f16",
- },
- "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
- },
- {
- "SHADER_NAME": "soft_max_f32_mask_f32_sink",
- "REPLS": {
- "MASK_TYPE" : "f32",
- },
- "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
- },
- {
- "SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace",
- "REPLS": {
- "MASK_TYPE" : "f32",
- },
- "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
- },
- {
- "SHADER_NAME": "soft_max_f32_mask_f16_sink",
- "REPLS": {
- "MASK_TYPE" : "f16",
- },
- "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
- },
- {
- "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace",
- "REPLS": {
- "MASK_TYPE" : "f16",
- },
- "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
- }
-]
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(BASE_BINDINGS)
-@group(0) @binding(1)
-var<storage, read_write> dst: array<f32>;
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-#enddecl(BASE_BINDINGS)
-
-#decl(BASE_BINDINGS_INPLACE)
-@group(0) @binding(1)
-var<uniform> params: Params;
-#enddecl(BASE_BINDINGS_INPLACE)
-
-#decl(SINK_BINDINGS)
-@group(0) @binding(1)
-var<storage, read_write> sinks: array<f32>;
-
-@group(0) @binding(2)
-var<storage, read_write> dst: array<f32>;
-
-@group(0) @binding(3)
-var<uniform> params: Params;
-#enddecl(SINK_BINDINGS)
-
-#decl(SINK_BINDINGS_INPLACE)
-@group(0) @binding(1)
-var<storage, read_write> sinks: array<f32>;
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-#enddecl(SINK_BINDINGS_INPLACE)
-
-#decl(MASK_BINDINGS)
-@group(0) @binding(1)
-var<storage, read_write> mask: array<{{MASK_TYPE}}>;
-
-@group(0) @binding(2)
-var<storage, read_write> dst: array<f32>;
-
-@group(0) @binding(3)
-var<uniform> params: Params;
-#enddecl(MASK_BINDINGS)
-
-#decl(MASK_BINDINGS_INPLACE)
-@group(0) @binding(1)
-var<storage, read_write> mask: array<{{MASK_TYPE}}>;
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-#enddecl(MASK_BINDINGS_INPLACE)
-
-#decl(MASK_SINK_BINDINGS)
-@group(0) @binding(1)
-var<storage, read_write> mask: array<{{MASK_TYPE}}>;
-
-@group(0) @binding(2)
-var<storage, read_write> sinks: array<f32>;
-
-@group(0) @binding(3)
-var<storage, read_write> dst: array<f32>;
-
-@group(0) @binding(4)
-var<uniform> params: Params;
-#enddecl(MASK_SINK_BINDINGS)
-
-#decl(MASK_SINK_BINDINGS_INPLACE)
-@group(0) @binding(1)
-var<storage, read_write> mask: array<{{MASK_TYPE}}>;
-
-@group(0) @binding(2)
-var<storage, read_write> sinks: array<f32>;
-
-@group(0) @binding(3)
-var<uniform> params: Params;
-#enddecl(MASK_SINK_BINDINGS_INPLACE)
-
-#decl(NOT_INPLACE)
-fn inter_value(i: u32) -> f32 {
- return dst[i];
-}
-
-fn update(i: u32, val: f32) {
- dst[i] = val;
-}
-#enddecl(NOT_INPLACE)
-
-#decl(INPLACE)
-fn inter_value(i: u32) -> f32 {
- return src[i];
-}
-
-fn update(i: u32, val: f32) {
- src[i] = val;
-}
-#enddecl(INPLACE)
-
-#decl(NO_MASK)
-fn mask_val(i: u32) -> f32 {
- return 0.0;
-}
-#enddecl(NO_MASK)
-
-#decl(MASK)
-fn mask_val(i: u32) -> f32 {
- return f32(mask[i]);
-}
-#enddecl(MASK)
-
-#decl(NO_SINK)
-fn lower_max_bound(i2: u32) -> f32 {
- return -1e30;
-}
-
-fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
- return val;
-}
-#enddecl(NO_SINK)
-
-#decl(SINK)
-fn lower_max_bound(i2: u32) -> f32 {
- return sinks[params.offset_sinks + i2];
-}
-
-fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
- return val + exp(sinks[params.offset_sinks + i2] - max_val);
-}
-#enddecl(SINK)
-
-#end(DECLS)
-
-#define(SHADER)
-enable f16;
-
-struct Params {
- offset_src0: u32,
- offset_src1: u32,
- offset_sinks: u32,
- offset_dst: u32,
-
- // Strides (in elements)
- stride_src01: u32,
- stride_src02: u32,
- stride_src03: u32,
-
- stride_src11: u32,
- stride_src12: u32,
- stride_src13: u32,
-
- stride_dst1: u32,
- stride_dst2: u32,
- stride_dst3: u32,
-
- // shape of src0/dst
- ne: u32,
- ne0: u32,
- ne1: u32,
- ne2: u32,
-
- // shape of src1
- ne12: u32,
- ne13: u32,
-
- scale: f32,
- max_bias: f32,
- n_head_log2: f32,
- m0: f32,
- m1: f32,
-};
-
-@group(0) @binding(0)
-var<storage, read_write> src: array<f32>;
-
-DECLS
-
-const CACHE_SIZE: u32 = 16;
-
-override wg_size: u32;
-var<workgroup> scratch: array<f32, wg_size>;
-
-@compute @workgroup_size(wg_size)
-fn main(@builtin(workgroup_id) wid: vec3<u32>,
- @builtin(local_invocation_id) lid: vec3<u32>) {
-
- var i = wid.x;
- let i3 = i / (params.ne2 * params.ne1);
- i = i % (params.ne2 * params.ne1);
- let i2 = i / params.ne1;
- let i1 = i % params.ne1;
- let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
- let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
- let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
- let elems = (params.ne0 + wg_size - 1) / wg_size;
-
- let head = f32(i2);
- let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);
-
- var cache: array<f32, CACHE_SIZE>;
-
- var max_val = lower_max_bound(i2);
- var col = lid.x;
- for (var j: u32 = 0; j < elems; j++) {
- if (col >= params.ne0) {
- break;
- }
- let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
- max_val = max(max_val, val);
- if (col < CACHE_SIZE) {
- cache[col] = val;
- }
- col += wg_size;
- }
-
- scratch[lid.x] = max_val;
- workgroupBarrier();
- var offset = wg_size / 2;
- while (offset > 0) {
- if (lid.x < offset) {
- scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
- }
- offset = offset / 2;
- workgroupBarrier();
- }
- let row_max = scratch[0];
- workgroupBarrier();
-
- var sum = 0.0f;
- col = lid.x;
- for (var j: u32 = 0; j < elems; j++) {
- if (col >= params.ne0) {
- break;
- }
- let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col),
- cache[col], col < CACHE_SIZE);
- let ex = exp(val - row_max);
- sum += ex;
- if (col < CACHE_SIZE) {
- cache[col] = ex;
- } else {
- update(i_dst_row + col, ex);
- }
- col += wg_size;
- }
-
- scratch[lid.x] = sum;
- workgroupBarrier();
- offset = wg_size / 2;
- while (offset > 0) {
- if (lid.x < offset) {
- scratch[lid.x] += scratch[lid.x + offset];
- }
- offset = offset / 2;
- workgroupBarrier();
- }
- let row_sum = add_sinks(scratch[0], i2, row_max);
-
- let sum_recip = 1.0 / row_sum;
- col = lid.x;
- for (var j: u32 = 0; j < elems; j++) {
- if (col >= params.ne0) {
- break;
- }
- update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
- col += wg_size;
- }
-}
-#end(SHADER)
--- /dev/null
+enable f16;
+
+#ifdef MASK_F32
+#define MaskType f32
+#endif
+#ifdef MASK_F16
+#define MaskType f16
+#endif
+
+struct Params {
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_sinks: u32,
+ offset_dst: u32,
+
+ // Strides (in elements)
+ stride_src01: u32,
+ stride_src02: u32,
+ stride_src03: u32,
+
+ stride_src11: u32,
+ stride_src12: u32,
+ stride_src13: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // shape of src0/dst
+ ne: u32,
+ ne0: u32,
+ ne1: u32,
+ ne2: u32,
+
+ // shape of src1
+ ne12: u32,
+ ne13: u32,
+
+ scale: f32,
+ max_bias: f32,
+ n_head_log2: f32,
+ m0: f32,
+ m1: f32,
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+#ifdef HAS_MASK
+#ifdef HAS_SINK
+@group(0) @binding(1)
+var<storage, read_write> mask: array<MaskType>;
+@group(0) @binding(2)
+var<storage, read_write> sinks: array<f32>;
+
+#ifdef INPLACE
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(3)
+var<storage, read_write> dst: array<f32>;
+@group(0) @binding(4)
+var<uniform> params: Params;
+#endif
+
+#else
+@group(0) @binding(1)
+var<storage, read_write> mask: array<MaskType>;
+
+#ifdef INPLACE
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+@group(0) @binding(3)
+var<uniform> params: Params;
+#endif
+#endif
+
+#else
+#ifdef HAS_SINK
+@group(0) @binding(1)
+var<storage, read_write> sinks: array<f32>;
+
+#ifdef INPLACE
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+@group(0) @binding(3)
+var<uniform> params: Params;
+#endif
+
+#else
+#ifdef INPLACE
+@group(0) @binding(1)
+var<uniform> params: Params;
+#else
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+@group(0) @binding(2)
+var<uniform> params: Params;
+#endif
+#endif
+#endif
+
+#ifdef INPLACE
+fn inter_value(i: u32) -> f32 {
+ return src[i];
+}
+fn update(i: u32, val: f32) {
+ src[i] = val;
+}
+
+#else
+fn inter_value(i: u32) -> f32 {
+ return dst[i];
+}
+fn update(i: u32, val: f32) {
+ dst[i] = val;
+}
+#endif
+
+#ifdef HAS_MASK
+fn mask_val(i: u32) -> f32 {
+ return f32(mask[i]);
+}
+
+#else
+fn mask_val(i: u32) -> f32 {
+ return 0.0;
+}
+#endif
+
+#ifdef HAS_SINK
+fn lower_max_bound(i2: u32) -> f32 {
+ return sinks[params.offset_sinks + i2];
+}
+fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
+ return val + exp(sinks[params.offset_sinks + i2] - max_val);
+}
+#else
+fn lower_max_bound(i2: u32) -> f32 {
+ return -1e30;
+}
+fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
+ return val;
+}
+#endif
+
+const CACHE_SIZE: u32 = 16;
+var<workgroup> scratch: array<f32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+ @builtin(local_invocation_id) lid: vec3<u32>) {
+
+ var i = wid.x;
+ let i3 = i / (params.ne2 * params.ne1);
+ i = i % (params.ne2 * params.ne1);
+ let i2 = i / params.ne1;
+ let i1 = i % params.ne1;
+ let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
+ let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
+ let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+ let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
+
+ let head = f32(i2);
+ let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);
+
+ var cache: array<f32, CACHE_SIZE>;
+
+ var max_val = lower_max_bound(i2);
+ var col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
+ max_val = max(max_val, val);
+ if (col < CACHE_SIZE) {
+ cache[col] = val;
+ }
+ col += WG_SIZE;
+ }
+
+ scratch[lid.x] = max_val;
+ workgroupBarrier();
+ var offset: u32 = WG_SIZE / 2;
+ while (offset > 0) {
+ if (lid.x < offset) {
+ scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
+ }
+ offset = offset / 2;
+ workgroupBarrier();
+ }
+ let row_max = scratch[0];
+ workgroupBarrier();
+
+ var sum = 0.0f;
+ col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col),
+ cache[col], col < CACHE_SIZE);
+ let ex = exp(val - row_max);
+ sum += ex;
+ if (col < CACHE_SIZE) {
+ cache[col] = ex;
+ } else {
+ update(i_dst_row + col, ex);
+ }
+ col += WG_SIZE;
+ }
+
+ scratch[lid.x] = sum;
+ workgroupBarrier();
+ offset = WG_SIZE / 2;
+ while (offset > 0) {
+ if (lid.x < offset) {
+ scratch[lid.x] += scratch[lid.x + offset];
+ }
+ offset = offset / 2;
+ workgroupBarrier();
+ }
+ let row_sum = add_sinks(scratch[0], i2, row_max);
+
+ let sum_recip = 1.0 / row_sum;
+ col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
+ col += WG_SIZE;
+ }
+}
+