ctx->staged_param_bufs.push_back(params_bufs);
if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
ggml_backend_webgpu_submit_queue(ctx);
+ ggml_backend_webgpu_wait_on_submission(ctx);
}
}
}
case GGML_OP_SCALE:
ggml_webgpu_scale(ctx, src0, node);
break;
+ case GGML_OP_SOFT_MAX:
+ ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
+ break;
default:
return false;
}
case GGML_OP_SCALE:
supports_op = op->type == GGML_TYPE_F32;
break;
+ case GGML_OP_SOFT_MAX:
+ supports_op = op->type == GGML_TYPE_F32;
+ break;
default:
break;
}
ggml_webgpu_init_rope_pipeline(ctx);
ggml_webgpu_init_glu_pipeline(ctx);
ggml_webgpu_init_scale_pipeline(ctx);
+ ggml_webgpu_init_soft_max_pipeline(ctx);
#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers
let i2 = i / params.ne1;
let i1 = i % params.ne1;
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
- let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+ let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
let elems = (params.ne0 + wg_size - 1) / wg_size;