]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml webgpu: minor set rows optimization (#16810)
authorReese Levine <redacted>
Wed, 5 Nov 2025 09:27:42 +0000 (01:27 -0800)
committerGitHub <redacted>
Wed, 5 Nov 2025 09:27:42 +0000 (10:27 +0100)
* Add buffer label and enable dawn-specific toggles to turn off some checks

* Minor set_rows optimization (#4)

* updated optimization, fixed errors

* non vectorized version now dispatches one thread per element

* Simplify

* Change logic for set_rows pipelines

---------

Co-authored-by: Neha Abbas <redacted>
Co-authored-by: Neha Abbas <redacted>
Co-authored-by: Reese Levine <redacted>
* Comment on dawn toggles

* Remove some comments

* Implement overlap binary operators

* Revert "Implement overlap binary operators"

This reverts commit ed710b36f51ab3f53fa13db15c1685dc8678a32a.

* Disable support for non-contiguous binary_op tensors and leave note for future support

---------

Co-authored-by: neha-ha <redacted>
Co-authored-by: Neha Abbas <redacted>
Co-authored-by: Neha Abbas <redacted>
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl [deleted file]

index 05e16cd432ad3c16125a9f442389df88e476ef62..1a15756731580f2868543c337814a45a8d0f8ea6 100644 (file)
@@ -248,7 +248,7 @@ struct webgpu_context_struct {
 
     webgpu_pipeline memset_pipeline;
     webgpu_pipeline mul_mat_pipeline[30][2];
-    webgpu_pipeline set_rows_pipeline;
+    webgpu_pipeline set_rows_pipeline[1][2];  // dst->type, vectorized
     webgpu_pipeline get_rows_pipeline[30];
     webgpu_pipeline get_rows_f32_no_vec_pipeline;
     webgpu_pipeline cpy_pipeline[2][2];          // src type, dst type
@@ -309,10 +309,12 @@ struct ggml_backend_webgpu_context {
 struct ggml_backend_webgpu_buffer_context {
     webgpu_context webgpu_ctx;
     wgpu::Buffer   buffer;
+    std::string    label;
 
-    ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
+    ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) :
         webgpu_ctx(std::move(ctx)),
-        buffer(std::move(buf)) {}
+        buffer(std::move(buf)),
+        label(std::move(lbl)) {}
 };
 
 /* End struct definitions */
@@ -764,10 +766,20 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
         { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
     };
 
-    size_t   max_wg_size = ctx->max_wg_size_x;
-    uint32_t wg_x        = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
+    size_t max_wg_size = ctx->max_wg_size_x;
+
+    int             vectorized = src->ne[0] % 4 == 0;
+    webgpu_pipeline pipeline   = ctx->set_rows_pipeline[0][vectorized];
+    uint32_t        threads;
+    if (vectorized) {
+        threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
+    } else {
+        threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
+    }
 
-    return ggml_backend_webgpu_build(ctx, ctx->set_rows_pipeline, params, entries, wg_x, error_bufs);
+    uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size;
+
+    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs);
 }
 
 static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
@@ -1336,11 +1348,11 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
 
     WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor);
 
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", "
-                                                                 << offset << ", " << size << ")");
-
     ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
 
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
+                                                                 << ", " << offset << ", " << size << ")");
+
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
 
     // This is a trick to set all bytes of a u32 to the same 1 byte value.
@@ -1354,12 +1366,13 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                   const void *          data,
                                                   size_t                offset,
                                                   size_t                size) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", "
-                                                              << offset << ", " << size << ")");
     WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
     ggml_backend_webgpu_buffer_context * buf_ctx    = (ggml_backend_webgpu_buffer_context *) buffer->context;
     webgpu_context                       webgpu_ctx = buf_ctx->webgpu_ctx;
 
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
+                                                              << ", " << offset << ", " << size << ")");
+
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
 
     webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
@@ -1397,12 +1410,12 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
                                                   void *                data,
                                                   size_t                offset,
                                                   size_t                size) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", "
-                                                              << offset << ", " << size << ")");
     WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor);
-    ggml_backend_webgpu_buffer_context * buf_ctx    = (ggml_backend_webgpu_buffer_context *) buffer->context;
-    webgpu_context                       webgpu_ctx = buf_ctx->webgpu_ctx;
-    wgpu::Device                         device     = webgpu_ctx->device;
+    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
+                                                              << ", " << offset << ", " << size << ")");
+    webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
+    wgpu::Device   device     = webgpu_ctx->device;
 
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
 
@@ -1473,16 +1486,20 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer
 
 static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
                                                                           size_t                     size) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");
+    static std::atomic<int> buffer_count;
+    int                     buffer_id = buffer_count++;
+    std::string             buf_name  = "tensor_buf" + std::to_string(buffer_id);
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes");
     ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
 
     wgpu::Buffer buf;
     ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf,
                               (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1),
                               wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
-                              "allocated_buffer");
+                              buf_name.c_str());
 
-    ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf);
+    ggml_backend_webgpu_buffer_context * buf_ctx =
+        new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf, buf_name);
 
     return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
 }
@@ -1613,8 +1630,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
-                                ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][0], wgsl_set_rows_f16,
+                                "set_rows_f16", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][1], wgsl_set_rows_f16_vec,
+                                "set_rows_f16_vec", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
 }
 
 static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
@@ -1950,8 +1969,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+            // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
+            // see https://github.com/ggml-org/llama.cpp/pull/16857
             supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
-                          (src1->type == op->type);
+                          (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
             break;
         case GGML_OP_CPY:
         case GGML_OP_CONT:
@@ -2129,6 +2150,19 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     required_features.push_back(wgpu::FeatureName::TimestampQuery);
 #endif
 
+    // Enable Dawn-specific toggles to increase native performance
+    // TODO: Don't enable for WASM builds, they won't have an effect anyways
+    // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
+    //       only for native performance?
+    const char * const deviceEnabledToggles[]  = { "skip_validation", "disable_robustness", "disable_workgroup_init",
+                                                   "disable_polyfills_on_integer_div_and_mod" };
+    const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
+    wgpu::DawnTogglesDescriptor deviceTogglesDesc;
+    deviceTogglesDesc.enabledToggles      = deviceEnabledToggles;
+    deviceTogglesDesc.enabledToggleCount  = 4;
+    deviceTogglesDesc.disabledToggles     = deviceDisabledToggles;
+    deviceTogglesDesc.disabledToggleCount = 1;
+
     wgpu::DeviceDescriptor dev_desc;
     dev_desc.requiredLimits       = &ctx->limits;
     dev_desc.requiredFeatures     = required_features.data();
@@ -2146,6 +2180,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
             GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
                        std::string(message).c_str());
         });
+    dev_desc.nextInChain = &deviceTogglesDesc;
     ctx->instance.WaitAny(ctx->adapter.RequestDevice(
                               &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
                               [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
@@ -2243,11 +2278,18 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
     ctx.name         = GGML_WEBGPU_NAME;
     ctx.device_count = 1;
 
+    const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
+
+    wgpu::DawnTogglesDescriptor instanceTogglesDesc;
+    instanceTogglesDesc.enabledToggles     = instanceEnabledToggles;
+    instanceTogglesDesc.enabledToggleCount = 1;
     wgpu::InstanceDescriptor               instance_descriptor{};
     std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
     instance_descriptor.requiredFeatures                     = instance_features.data();
     instance_descriptor.requiredFeatureCount                 = instance_features.size();
-    webgpu_ctx->instance                                     = wgpu::CreateInstance(&instance_descriptor);
+    instance_descriptor.nextInChain                          = &instanceTogglesDesc;
+
+    webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
     GGML_ASSERT(webgpu_ctx->instance != nullptr);
 
     static ggml_backend_reg reg = {
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl
new file mode 100644 (file)
index 0000000..fca3be6
--- /dev/null
@@ -0,0 +1,112 @@
+#define(VARIANTS)
+
+[
+  {
+    "SHADER_SUFFIX": "f16_vec",
+    "REPLS": {
+      "TYPE" : "vec4<f32>",
+      "DST_TYPE": "vec4<f16>",
+      "VEC_SIZE": 4
+    }
+  },
+  {
+    "SHADER_SUFFIX": "f16",
+    "REPLS": {
+      "TYPE" : "f32",
+      "DST_TYPE": "f16",
+      "VEC_SIZE": 1
+    }
+  }
+]
+
+#end(VARIANTS)
+
+#define(SHADER)
+
+enable f16;
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<{{TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> idx: array<u32>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<{{DST_TYPE}}>;
+
+@group(0) @binding(3)
+var<storage, read_write> error: atomic<u32>;
+
+struct Params {
+    offset_src: u32, // in elements
+    offset_idx: u32, // in elements
+    offset_dst: u32, // in elements
+
+    // Strides (in elements)
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_idx0: u32,
+    stride_idx1: u32,
+    stride_idx2: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // Shape of src
+    ne0: u32,
+    n_rows: u32,
+    ne2: u32,
+    ne3: u32,
+
+    // Shape of idx
+    idx1: u32,
+    idx2: u32,
+};
+
+@group(0) @binding(4)
+var<uniform> params: Params;
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / {{VEC_SIZE}}) {
+        return;
+    }
+
+    // getting the row from gid
+    let elems_per_row = params.ne0 / {{VEC_SIZE}};
+    var i = gid.x / elems_per_row;
+
+    let i_src3 = i / (params.ne2 * params.n_rows);
+
+    i = i % (params.ne2 * params.n_rows);
+    let i_src2 = i / params.n_rows;
+    let i_src1 = i % params.n_rows;
+
+    let i_idx2 = i_src3 % params.idx2;
+    let i_idx1 = i_src2 % params.idx1;
+    let i_idx0 = i_src1;
+
+    let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
+
+    let idx_high_val = idx[idx_high];
+    let idx_low_val = idx[idx_high + 1];
+
+    if (idx_low_val != 0) {
+        // Upper bits of index are not zero, output will be incorrect
+        atomicStore(&error, 1);
+        return;
+    }
+
+    let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
+    let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
+
+    let col_idx = (gid.x % elems_per_row);
+    dst[i_dst_row/{{VEC_SIZE}} + col_idx] = {{DST_TYPE}}(src[i_src_row/{{VEC_SIZE}} + col_idx]);
+}
+
+#end(SHADER)
+
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
deleted file mode 100644 (file)
index 3567713..0000000
+++ /dev/null
@@ -1,81 +0,0 @@
-enable f16;
-
-@group(0) @binding(0)
-var<storage, read_write> src: array<f32>;
-
-@group(0) @binding(1)
-var<storage, read_write> idx: array<u32>;
-
-@group(0) @binding(2)
-var<storage, read_write> dst: array<f16>;
-
-@group(0) @binding(3)
-var<storage, read_write> error: atomic<u32>;
-
-struct Params {
-    offset_src: u32, // in elements
-    offset_idx: u32, // in elements
-    offset_dst: u32, // in elements
-
-    // Strides (in elements)
-    stride_src1: u32,
-    stride_src2: u32,
-    stride_src3: u32,
-
-    stride_idx0: u32,
-    stride_idx1: u32,
-    stride_idx2: u32,
-
-    stride_dst1: u32,
-    stride_dst2: u32,
-    stride_dst3: u32,
-
-    // Shape of src
-    ne0: u32,
-    n_rows: u32,
-    ne2: u32,
-    ne3: u32,
-
-    // Shape of idx
-    idx1: u32,
-    idx2: u32,
-};
-
-@group(0) @binding(4)
-var<uniform> params: Params;
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
-        return;
-    }
-    var i = gid.x;
-    let i_src3 = i / (params.ne2 * params.n_rows);
-
-    i = i % (params.ne2 * params.n_rows);
-    let i_src2 = i / params.n_rows;
-    let i_src1 = i % params.n_rows;
-
-    let i_idx2 = i_src3 % params.idx2;
-    let i_idx1 = i_src2 % params.idx1;
-    let i_idx0 = i_src1;
-
-    let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
-
-    let idx_high_val = idx[idx_high];
-    let idx_low_val = idx[idx_high + 1];
-
-    if (idx_low_val != 0) {
-        // Upper bits of index are not zero, output will be incorrect
-        atomicStore(&error, 1);
-        return;
-    }
-
-    let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
-    let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
-
-    for (var i: u32 = 0; i < params.ne0; i++) {
-      dst[i_dst_row + i] = f16(src[i_src_row + i]);
-    }
-}