size_t wg_mem_limit_bytes = 0;
bool inplace = false;
bool overlap = false;
+ bool src_overlap = false;
bool supports_subgroup_matrix = false;
uint32_t sg_mat_m = 0;
uint32_t sg_mat_n = 0;
int op;
bool inplace;
bool overlap;
+ bool src_overlap;
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
- return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
+ return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
}
};
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.inplace);
ggml_webgpu_hash_combine(seed, key.overlap);
+ ggml_webgpu_hash_combine(seed, key.src_overlap);
return seed;
}
};
.op = context.dst->op,
.inplace = context.inplace,
.overlap = context.overlap,
+ .src_overlap = context.src_overlap,
};
auto it = binary_pipelines.find(key);
} else if (key.overlap) {
defines.push_back("OVERLAP");
variant += "_overlap";
+ } else if (key.src_overlap) {
+ defines.push_back("SRC_OVERLAP");
+ variant += "_src_overlap";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
struct binary_overlap_flags {
bool inplace; // src0 == dst
- bool overlap; // src1 == dst
+ bool overlap;     // src1 overlaps dst
+ bool src_overlap; // src0 and src1 overlap each other, dst is separate
};
static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
binary_overlap_flags flags = {};
flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
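+ // merge the sources only when dst aliases neither of them: the pipeline
+ // variant chain checks inplace/overlap before src_overlap, and those
+ // variants bind src0/src1 directly, which is incompatible with the merged
+ // source binding used below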
+ flags.src_overlap = !flags.inplace && !flags.overlap && ggml_webgpu_tensor_overlap(src0, src1);
return flags;
}
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = flags.inplace,
.overlap = flags.overlap,
+ .src_overlap = flags.src_overlap,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
uint32_t ne = (uint32_t) ggml_nelements(dst);
+ size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0);
+ size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1);
+
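+ // element offsets of src0/src1 inside the merged source binding set up
+ // below; both stay zero when the sources do not overlap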
+ uint32_t offset_merged_src0 = 0;
+ uint32_t offset_merged_src1 = 0;
+ if (flags.src_overlap) {
+ size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
+ offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
+ offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src1->type));
+ }
+
std::vector<uint32_t> params = {
ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ offset_merged_src0,
+ offset_merged_src1,
+ (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
std::vector<wgpu::BindGroupEntry> entries;
- entries.push_back({
- .binding = 0,
- .buffer = ggml_webgpu_tensor_buf(src0),
- .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
- .size = ggml_webgpu_tensor_binding_size(ctx, src0),
- });
-
- entries.push_back({
- .binding = 1,
- .buffer = ggml_webgpu_tensor_buf(src1),
- .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
- .size = ggml_webgpu_tensor_binding_size(ctx, src1),
- });
-
- if (!flags.inplace && !flags.overlap) {
- entries.push_back({ .binding = 2,
- .buffer = ggml_webgpu_tensor_buf(dst),
- .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
- .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+ if (flags.src_overlap) {
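+ // overlapping sources necessarily live in the same wgpu buffer, so bind a
+ // single range spanning both and index into it with the merged offsets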
+ size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
+ size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0),
+ src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1));
+ entries.push_back({
+ .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src0),
+ .offset = merged_offset,
+ .size = merged_end - merged_offset,
+ });
+ entries.push_back({
+ .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst),
+ });
+ } else {
+ entries.push_back({
+ .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src0),
+ .offset = src0_webgpu_tensor_align_offset,
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0),
+ });
+ entries.push_back({
+ .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(src1),
+ .offset = src1_webgpu_tensor_align_offset,
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1),
+ });
+ if (!flags.inplace && !flags.overlap) {
+ entries.push_back({
+ .binding = 2,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst),
+ });
+ }
}
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
- // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
- // see https://github.com/ggml-org/llama.cpp/pull/16857
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
- (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
+ (src1->type == op->type);
break;
case GGML_OP_CPY:
case GGML_OP_CONT:
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
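+ // element offsets of src0/src1 within the merged source binding; zero
+ // unless the SRC_OVERLAP variant is in use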
+ offset_merged_src0: u32,
+ offset_merged_src1: u32,
+
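+ // src0 strides in elements, so non-contiguous src0 tensors can be indexed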
+ stride_src0_0: u32,
+ stride_src0_1: u32,
+ stride_src0_2: u32,
+ stride_src0_3: u32,
stride_src1_0: u32,
stride_src1_1: u32,
b_ne3: u32,
};
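+// decompose a flat dst element index into src0 coordinates and apply src0's
+// per-dimension strides (handles non-contiguous src0)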
+fn src0_index(_i: u32) -> u32 {
+ var i = _i;
+ let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
+ i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
+ let a_i2 = i / (params.a_ne1 * params.a_ne0);
+ i = i % (params.a_ne1 * params.a_ne0);
+ let a_i1 = i / params.a_ne0;
+ let a_i0 = i % params.a_ne0;
+
+ return a_i0 * params.stride_src0_0 +
+ a_i1 * params.stride_src0_1 +
+ a_i2 * params.stride_src0_2 +
+ a_i3 * params.stride_src0_3;
+}
+
fn src1_index(_i: u32) -> u32 {
var i = _i;
let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
#define DataType f16
#endif
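+// with SRC_OVERLAP, the two overlapping sources are bound as one merged range
+// at binding 0 and dst takes binding 1; otherwise the layout is unchanged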
+#ifdef SRC_OVERLAP
+@group(0) @binding(0)
+var<storage, read_write> merged_src: array<DataType>;
+@group(0) @binding(1)
+var<storage, read_write> dst: array<DataType>;
+@group(0) @binding(2)
+var<uniform> params: Params;
+#else
@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;
@group(0) @binding(1)
var<storage, read_write> src1 : array<DataType>;
-#ifdef INPLACE
-@group(0) @binding(2)
-var<uniform> params: Params;
-#elif defined(OVERLAP)
+#if defined(INPLACE) || defined(OVERLAP)
@group(0) @binding(2)
var<uniform> params: Params;
#else
@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
+#endif
fn op(a: DataType, b: DataType) -> DataType {
#ifdef OP_ADD
#endif
}
-fn update(dst_i: u32, src0_i: u32, src1_i: u32){
+fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
+#ifdef SRC_OVERLAP
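+ // both operands read from the single merged source binding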
+ let result = op(merged_src[src0_i], merged_src[src1_i]);
+#else
let result = op(src0[src0_i], src1[src1_i]);
+#endif
#ifdef INPLACE
- src0[dst_i] = result;
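+ // dst aliases a source here; write through that source's computed index so
+ // non-contiguous strides are respected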
+ src0[src0_i] = result;
#elif defined(OVERLAP)
- src1[dst_i] = result;
+ src1[src1_i] = result;
#else
dst[dst_i] = result;
#endif
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
- update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
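+ // offset_merged_* are nonzero only for the SRC_OVERLAP variant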
+ let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x);
+ let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x);
+ update(params.offset_dst + gid.x, src0_i, src1_i);
}
}