/* Constants */
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
+#define WEBGPU_WAIT_ANY_BATCH_SIZE 64
#define WEBGPU_MUL_MAT_WG_SIZE 64
#define WEBGPU_NUM_PARAM_BUFS 100
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
+// For operations which process a row in parallel, this seems like a reasonable default
+#define WEBGPU_ROW_SPLIT_WG_SIZE 64
+
/* End Constants */
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
wgpu::ComputePipeline set_rows_pipeline;
wgpu::ComputePipeline get_rows_pipeline[30];
wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
- wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type
- wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace
- wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace
- wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace
- wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace
- wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace
- wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace
- wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split
- wgpu::ComputePipeline scale_pipeline[2]; // inplace
+ wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type
+ wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace
+ wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace
+ wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace
+ wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace
+ wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace
+ wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace
+ wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split
+ wgpu::ComputePipeline scale_pipeline[2]; // inplace
+ wgpu::ComputePipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace
size_t memset_bytes_per_thread;
}),
UINT64_MAX);
} else {
- // existing callbacks, wait on them
- ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
+ // WebGPU implementations may limit the number of futures that can be waited on at once,
+ // so wait in batches (64 is what Dawn supports).
+ for (size_t i = 0; i < ctx->callback_futures.size(); i += WEBGPU_WAIT_ANY_BATCH_SIZE) {
+ size_t end = std::min(i + WEBGPU_WAIT_ANY_BATCH_SIZE, ctx->callback_futures.size());
+ ctx->instance.WaitAny(end - i, ctx->callback_futures.data() + i, UINT64_MAX);
+ }
ctx->callback_futures.clear();
}
}
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
- size_t max_wg_size = ctx->max_wg_size_x;
- uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
- ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x,
+ ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src),
ggml_op_name(dst->op));
}
ggml_op_name(dst->op));
}
+static void ggml_webgpu_soft_max(webgpu_context & ctx,
+ ggml_tensor * src0,
+ ggml_tensor * src1,
+ ggml_tensor * src2,
+ ggml_tensor * dst) {
+ const int inplace = ggml_webgpu_tensor_equal(src0, dst);
+ const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here
+ const int has_sink = (src2 != nullptr);
+ float max_bias;
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+ float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
+ float m0 = powf(2.0f, -(max_bias) / n_head_log2);
+ float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ std::vector<uint32_t> params = {
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+ mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
+ has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+ mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
+ mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
+ mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+ (uint32_t) ggml_nelements(dst),
+ (uint32_t) src0->ne[0],
+ (uint32_t) src0->ne[1],
+ (uint32_t) src0->ne[2],
+ mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
+ mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
+ *(uint32_t *) dst->op_params, // scale
+ *(uint32_t *) &max_bias,
+ *(uint32_t *) &n_head_log2,
+ *(uint32_t *) &m0,
+ *(uint32_t *) &m1
+ };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src0),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) }
+ };
+ uint32_t binding_num = 1;
+ if (mask_type < 2) {
+ entries.push_back({ .binding = binding_num,
+ .buffer = ggml_webgpu_tensor_buf(src1),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
+ binding_num++;
+ }
+ if (has_sink) {
+ entries.push_back({ .binding = binding_num,
+ .buffer = ggml_webgpu_tensor_buf(src2),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
+ binding_num++;
+ }
+ if (!inplace) {
+ entries.push_back({ .binding = binding_num,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+ }
+
+ ggml_backend_webgpu_build_and_enqueue(ctx, ctx->soft_max_pipeline[mask_type][has_sink][inplace], params, entries,
+ ggml_nrows(dst), ggml_op_name(dst->op));
+}
+
// Returns true if node has enqueued work into the queue, false otherwise
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
if (ggml_is_empty(node)) {
return reinterpret_cast<ggml_guid_t>((void *) guid_str);
}
-// The max workgroup size is a common constant
-static std::vector<wgpu::ConstantEntry> ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) {
+// Workgroup size is a common constant
+static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
- constants[0].value = webgpu_ctx->max_wg_size_x;
+ constants[0].value = wg_size;
return constants;
}
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
- ggml_webgpu_max_wg_size_entry(webgpu_ctx));
+ ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
}
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec,
"get_rows_f32_vec", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32,
}
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16],
}
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16",
}
static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16",
}
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16",
}
static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16",
}
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace,
}
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32,
"rope_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1],
}
static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
// reglu
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0],
wgsl_reglu_f32, "reglu_f32", constants);
}
static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace,
"scale_f32_inplace", constants);
}
+static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][0], wgsl_soft_max_f32,
+ "soft_max_f32", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][1], wgsl_soft_max_f32_inplace,
+ "soft_max_f32_inplace", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][0], wgsl_soft_max_f32_sink,
+ "soft_max_f32_sink", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][1],
+ wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][0], wgsl_soft_max_f32_mask_f32,
+ "soft_max_f32_mask_f32", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][1],
+ wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][0], wgsl_soft_max_f32_mask_f16,
+ "soft_max_f32_mask_f16", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][1],
+ wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][0],
+ wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][1],
+ wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace",
+ constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][0],
+ wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][1],
+ wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace",
+ constants);
+}
+
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params);
ggml_tensor * src0 = op->src[0];
ggml_tensor * src1 = op->src[1];
+ ggml_tensor * src2 = op->src[2];
// on smaller devices (or CI), tensors may be larger than the max storage buffer size
if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
break;
case GGML_OP_SET_ROWS:
- supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64);
+ supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64);
break;
case GGML_OP_GET_ROWS:
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
default:
break;
}
-#ifdef GGML_WEBGPU_DEBUG
+ if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
+ (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
+ (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
+ (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
+ supports_op = false;
+ WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
+ }
+
if (!supports_op) {
- WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
- << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
- << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
+ WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: "
+ << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
+ << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
+ << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
+ } else {
+ WEBGPU_LOG_DEBUG("ggml_webgpu op supported: "
+ << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
+ << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
+ << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
}
-#endif
return supports_op;
}
--- /dev/null
+#define(VARIANTS)
+[
+ {
+ "SHADER_NAME": "soft_max_f32",
+ "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_inplace",
+ "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_sink",
+ "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_sink_inplace",
+ "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f32",
+ "REPLS": {
+ "MASK_TYPE" : "f32",
+ },
+ "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f32_inplace",
+ "REPLS": {
+ "MASK_TYPE" : "f32",
+ },
+ "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f16",
+ "REPLS": {
+ "MASK_TYPE" : "f16",
+ },
+ "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f16_inplace",
+ "REPLS": {
+ "MASK_TYPE" : "f16",
+ },
+ "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f32_sink",
+ "REPLS": {
+ "MASK_TYPE" : "f32",
+ },
+ "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace",
+ "REPLS": {
+ "MASK_TYPE" : "f32",
+ },
+ "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f16_sink",
+ "REPLS": {
+ "MASK_TYPE" : "f16",
+ },
+ "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace",
+ "REPLS": {
+ "MASK_TYPE" : "f16",
+ },
+ "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
+ }
+]
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(BASE_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+#enddecl(BASE_BINDINGS)
+
+#decl(BASE_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<uniform> params: Params;
+#enddecl(BASE_BINDINGS_INPLACE)
+
+#decl(SINK_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#enddecl(SINK_BINDINGS)
+
+#decl(SINK_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+#enddecl(SINK_BINDINGS_INPLACE)
+
+#decl(MASK_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#enddecl(MASK_BINDINGS)
+
+#decl(MASK_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+#enddecl(MASK_BINDINGS_INPLACE)
+
+#decl(MASK_SINK_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(3)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(4)
+var<uniform> params: Params;
+#enddecl(MASK_SINK_BINDINGS)
+
+#decl(MASK_SINK_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#enddecl(MASK_SINK_BINDINGS_INPLACE)
+
+#decl(NOT_INPLACE)
+fn inter_value(i: u32) -> f32 {
+ return dst[i];
+}
+
+fn update(i: u32, val: f32) {
+ dst[i] = val;
+}
+#enddecl(NOT_INPLACE)
+
+#decl(INPLACE)
+fn inter_value(i: u32) -> f32 {
+ return src[i];
+}
+
+fn update(i: u32, val: f32) {
+ src[i] = val;
+}
+#enddecl(INPLACE)
+
+#decl(NO_MASK)
+fn mask_val(i: u32) -> f32 {
+ return 0.0;
+}
+#enddecl(NO_MASK)
+
+#decl(MASK)
+fn mask_val(i: u32) -> f32 {
+ return f32(mask[i]);
+}
+#enddecl(MASK)
+
+#decl(NO_SINK)
+fn lower_max_bound(i2: u32) -> f32 {
+ return -1e30;
+}
+
+fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
+ return val;
+}
+#enddecl(NO_SINK)
+
+#decl(SINK)
+fn lower_max_bound(i2: u32) -> f32 {
+ return sinks[params.offset_sinks + i2];
+}
+
+fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
+ return val + exp(sinks[params.offset_sinks + i2] - max_val);
+}
+#enddecl(SINK)
+
+#end(DECLS)
+
+#define(SHADER)
+enable f16;
+
+struct Params {
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_sinks: u32,
+ offset_dst: u32,
+
+ // Strides (in elements)
+ stride_src01: u32,
+ stride_src02: u32,
+ stride_src03: u32,
+
+ stride_src11: u32,
+ stride_src12: u32,
+ stride_src13: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // shape of src0/dst
+ ne: u32,
+ ne0: u32,
+ ne1: u32,
+ ne2: u32,
+
+ // shape of src1
+ ne12: u32,
+ ne13: u32,
+
+ scale: f32,
+ max_bias: f32,
+ n_head_log2: f32,
+ m0: f32,
+ m1: f32,
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+DECLS
+
+const CACHE_SIZE: u32 = 16;
+
+override wg_size: u32;
+var<workgroup> scratch: array<f32, wg_size>;
+
+@compute @workgroup_size(wg_size)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+ @builtin(local_invocation_id) lid: vec3<u32>) {
+
+ var i = wid.x;
+ let i3 = i / (params.ne2 * params.ne1);
+ i = i % (params.ne2 * params.ne1);
+ let i2 = i / params.ne1;
+ let i1 = i % params.ne1;
+ let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
+ let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
+ let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+ let elems = (params.ne0 + wg_size - 1) / wg_size;
+
+ let head = f32(i2);
+ let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);
+
+ var cache: array<f32, CACHE_SIZE>;
+
+ var max_val = lower_max_bound(i2);
+ var col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
+ max_val = max(max_val, val);
+ if (col < CACHE_SIZE) {
+ cache[col] = val;
+ }
+ col += wg_size;
+ }
+
+ scratch[lid.x] = max_val;
+ workgroupBarrier();
+ var offset = wg_size / 2;
+ while (offset > 0) {
+ if (lid.x < offset) {
+ scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
+ }
+ offset = offset / 2;
+ workgroupBarrier();
+ }
+ let row_max = scratch[0];
+
+ var sum = 0.0f;
+ col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col),
+ cache[col], col < CACHE_SIZE);
+ let ex = exp(val - row_max);
+ sum += ex;
+ if (col < CACHE_SIZE) {
+ cache[col] = ex;
+ } else {
+ update(i_dst_row + col, ex);
+ }
+ col += wg_size;
+ }
+
+ scratch[lid.x] = sum;
+ workgroupBarrier();
+ offset = wg_size / 2;
+ while (offset > 0) {
+ if (lid.x < offset) {
+ scratch[lid.x] += scratch[lid.x + offset];
+ }
+ offset = offset / 2;
+ workgroupBarrier();
+ }
+ let row_sum = add_sinks(scratch[0], i2, row_max);
+
+ let sum_recip = 1.0 / row_sum;
+ col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
+ col += wg_size;
+ }
+}
+#end(SHADER)