]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml webgpu: add support for soft_max, optimize rms_norm (#16357)
authorReese Levine <redacted>
Thu, 2 Oct 2025 18:00:31 +0000 (11:00 -0700)
committerGitHub <redacted>
Thu, 2 Oct 2025 18:00:31 +0000 (11:00 -0700)
* Add inplace softmax

* Move rms_norm to split row approach

* Update debug for supports_op

* clean up debug statements

* Update tests/test-backend-ops.cpp

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
ggml/include/ggml.h
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl
ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl [new file with mode: 0644]
ggml/src/ggml.c
tests/test-backend-ops.cpp

index f65eb75e298fa99dc655f3852606825ea2c81a06..60c6b63d05978fa2a2c457d8db8e79371040c0f3 100644 (file)
@@ -1630,6 +1630,13 @@ extern "C" {
             float                 scale,
             float                 max_bias);
 
+    GGML_API struct ggml_tensor * ggml_soft_max_ext_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * mask,
+            float                 scale,
+            float                 max_bias);
+
     GGML_API void ggml_soft_max_add_sinks(
             struct ggml_tensor * a,
             struct ggml_tensor * sinks);
index 93200a4d29f53bc5c5f7043b0ec246f8d03f99ba..de68c5689bba730be14aa6a83e4f070a40db5b36 100644 (file)
@@ -28,6 +28,7 @@
 /* Constants */
 
 #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     16
+#define WEBGPU_WAIT_ANY_BATCH_SIZE           64
 #define WEBGPU_MUL_MAT_WG_SIZE               64
 #define WEBGPU_NUM_PARAM_BUFS                100
 #define WEBGPU_PARAMS_BUF_SIZE_BYTES         128  // enough for 32 parameters
@@ -35,6 +36,9 @@
 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
 #define WEBGPU_STORAGE_BUF_BINDING_MULT      4  // a storage buffer binding size must be a multiple of 4
 
+// For operations which process a row in parallel, this seems like a reasonable default
+#define WEBGPU_ROW_SPLIT_WG_SIZE 64
+
 /* End Constants */
 
 // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
@@ -130,15 +134,16 @@ struct webgpu_context_struct {
     wgpu::ComputePipeline set_rows_pipeline;
     wgpu::ComputePipeline get_rows_pipeline[30];
     wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
-    wgpu::ComputePipeline cpy_pipeline[2][2];      // src type, dst type
-    wgpu::ComputePipeline add_pipeline[2][2];      // type, inplace
-    wgpu::ComputePipeline sub_pipeline[2][2];      // type, inplace
-    wgpu::ComputePipeline mul_pipeline[2][2];      // type, inplace
-    wgpu::ComputePipeline div_pipeline[2][2];      // type, inplace
-    wgpu::ComputePipeline rms_norm_pipeline[2];    // inplace
-    wgpu::ComputePipeline rope_pipeline[2][2][2];  // type, ff, inplace
-    wgpu::ComputePipeline glu_pipeline[7][2][2];   // glu-op, type, split
-    wgpu::ComputePipeline scale_pipeline[2];       // inplace
+    wgpu::ComputePipeline cpy_pipeline[2][2];          // src type, dst type
+    wgpu::ComputePipeline add_pipeline[2][2];          // type, inplace
+    wgpu::ComputePipeline sub_pipeline[2][2];          // type, inplace
+    wgpu::ComputePipeline mul_pipeline[2][2];          // type, inplace
+    wgpu::ComputePipeline div_pipeline[2][2];          // type, inplace
+    wgpu::ComputePipeline rms_norm_pipeline[2];        // inplace
+    wgpu::ComputePipeline rope_pipeline[2][2][2];      // type, ff, inplace
+    wgpu::ComputePipeline glu_pipeline[7][2][2];       // glu-op, type, split
+    wgpu::ComputePipeline scale_pipeline[2];           // inplace
+    wgpu::ComputePipeline soft_max_pipeline[3][2][2];  // (no_mask, f32_mask, f16_mask), has_sink, inplace
 
     size_t memset_bytes_per_thread;
 
@@ -256,8 +261,12 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
                                            }),
             UINT64_MAX);
     } else {
-        // existing callbacks, wait on them
-        ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
+        // WebGPU implementations may limit the number of futures that can be waited on at once,
+        // so wait in batches (64 is what Dawn supports).
+        for (size_t i = 0; i < ctx->callback_futures.size(); i += WEBGPU_WAIT_ANY_BATCH_SIZE) {
+            size_t end = std::min(i + WEBGPU_WAIT_ANY_BATCH_SIZE, ctx->callback_futures.size());
+            ctx->instance.WaitAny(end - i, ctx->callback_futures.data() + i, UINT64_MAX);
+        }
         ctx->callback_futures.clear();
     }
 }
@@ -726,9 +735,7 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    size_t   max_wg_size = ctx->max_wg_size_x;
-    uint32_t wg_x        = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x,
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src),
                                           ggml_op_name(dst->op));
 }
 
@@ -912,6 +919,79 @@ static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tens
                                           ggml_op_name(dst->op));
 }
 
+static void ggml_webgpu_soft_max(webgpu_context & ctx,
+                                 ggml_tensor *    src0,
+                                 ggml_tensor *    src1,
+                                 ggml_tensor *    src2,
+                                 ggml_tensor *    dst) {
+    const int inplace   = ggml_webgpu_tensor_equal(src0, dst);
+    const int mask_type = (src1 != nullptr) ? src1->type : 2;  // use 2 for no mask here
+    const int has_sink  = (src2 != nullptr);
+    float     max_bias;
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+    float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
+    float m0          = powf(2.0f, -(max_bias) / n_head_log2);
+    float m1          = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+        mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
+        has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+        mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
+        mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
+        mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
+        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        (uint32_t) ggml_nelements(dst),
+        (uint32_t) src0->ne[0],
+        (uint32_t) src0->ne[1],
+        (uint32_t) src0->ne[2],
+        mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
+        mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
+        *(uint32_t *) dst->op_params,  // scale
+        *(uint32_t *) &max_bias,
+        *(uint32_t *) &n_head_log2,
+        *(uint32_t *) &m0,
+        *(uint32_t *) &m1
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+         .buffer  = ggml_webgpu_tensor_buf(src0),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) }
+    };
+    uint32_t binding_num = 1;
+    if (mask_type < 2) {
+        entries.push_back({ .binding = binding_num,
+                            .buffer  = ggml_webgpu_tensor_buf(src1),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, src1) });
+        binding_num++;
+    }
+    if (has_sink) {
+        entries.push_back({ .binding = binding_num,
+                            .buffer  = ggml_webgpu_tensor_buf(src2),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, src2) });
+        binding_num++;
+    }
+    if (!inplace) {
+        entries.push_back({ .binding = binding_num,
+                            .buffer  = ggml_webgpu_tensor_buf(dst),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
+    }
+
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->soft_max_pipeline[mask_type][has_sink][inplace], params, entries,
+                                          ggml_nrows(dst), ggml_op_name(dst->op));
+}
+
 // Returns true if node has enqueued work into the queue, false otherwise
 static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
     if (ggml_is_empty(node)) {
@@ -1237,11 +1317,11 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
     return reinterpret_cast<ggml_guid_t>((void *) guid_str);
 }
 
-// The max workgroup size is a common constant
-static std::vector<wgpu::ConstantEntry> ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) {
+// Workgroup size is a common constant
+static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
     std::vector<wgpu::ConstantEntry> constants(1);
     constants[0].key   = "wg_size";
-    constants[0].value = webgpu_ctx->max_wg_size_x;
+    constants[0].value = wg_size;
     return constants;
 }
 
@@ -1309,11 +1389,11 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
 
 static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
-                                ggml_webgpu_max_wg_size_entry(webgpu_ctx));
+                                ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
 }
 
 static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec,
                                 "get_rows_f32_vec", constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32,
@@ -1363,7 +1443,7 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
                                 wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16],
@@ -1375,7 +1455,7 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16",
@@ -1387,7 +1467,7 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16",
@@ -1399,7 +1479,7 @@ static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16",
@@ -1411,7 +1491,7 @@ static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16",
@@ -1423,7 +1503,7 @@ static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace,
@@ -1431,7 +1511,7 @@ static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32,
                                 "rope_f32", constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1],
@@ -1451,7 +1531,7 @@ static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     // reglu
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0],
                                 wgsl_reglu_f32, "reglu_f32", constants);
@@ -1505,13 +1585,43 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32",
                                 constants);
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace,
                                 "scale_f32_inplace", constants);
 }
 
+static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][0], wgsl_soft_max_f32,
+                                "soft_max_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][1], wgsl_soft_max_f32_inplace,
+                                "soft_max_f32_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][0], wgsl_soft_max_f32_sink,
+                                "soft_max_f32_sink", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][1],
+                                wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][0], wgsl_soft_max_f32_mask_f32,
+                                "soft_max_f32_mask_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][1],
+                                wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][0], wgsl_soft_max_f32_mask_f16,
+                                "soft_max_f32_mask_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][1],
+                                wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][0],
+                                wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][1],
+                                wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][0],
+                                wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][1],
+                                wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace",
+                                constants);
+}
+
 static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
     GGML_UNUSED(params);
 
@@ -1593,6 +1703,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
 
     ggml_tensor * src0 = op->src[0];
     ggml_tensor * src1 = op->src[1];
+    ggml_tensor * src2 = op->src[2];
 
     // on smaller devices (or CI), tensors may be larger than the max storage buffer size
     if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
@@ -1623,7 +1734,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                           (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
             break;
         case GGML_OP_SET_ROWS:
-            supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64);
+            supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64);
             break;
         case GGML_OP_GET_ROWS:
             if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
@@ -1698,13 +1809,25 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         default:
             break;
     }
-#ifdef GGML_WEBGPU_DEBUG
+    if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
+        (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
+        (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
+        (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
+        supports_op = false;
+        WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
+    }
+
     if (!supports_op) {
-        WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
-                                           << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
-                                           << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
+        WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: "
+                         << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
+                         << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
+                         << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
+    } else {
+        WEBGPU_LOG_DEBUG("ggml_webgpu op supported: "
+                         << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
+                         << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
+                         << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
     }
-#endif
     return supports_op;
 }
 
index a275eeb9783daa4eaee4fb2e9e46f5da6815d4d2..4f72bb1c851ec33d622ee1bcc81cc22236d5369a 100644 (file)
@@ -71,14 +71,14 @@ var<storage, read_write> src: array<f32>;
 DECLS
 
 override wg_size: u32;
+var<workgroup> scratch: array<f32, wg_size>;
+
 @compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
-        return;
-    }
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
 
     // one thread per row
-    var i = gid.x;
+    var i = wid.x;
     let i3 = i / (params.ne2 * params.ne1);
     i = i % (params.ne2 * params.ne1);
     let i2 = i / params.ne1;
@@ -86,13 +86,38 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
     let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
 
+    let elems = (params.ne0 + wg_size - 1) / wg_size;
+
     var sum = 0.0f;
-    for (var j: u32 = 0; j < params.ne0; j++) {
-        sum += src[i_src_row + j] * src[i_src_row + j];
+    var col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        sum += pow(src[i_src_row + col], 2.0);
+        col += wg_size;
     }
+
+    scratch[lid.x] = sum;
+    workgroupBarrier();
+    var offset = wg_size / 2;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            scratch[lid.x] += scratch[lid.x + offset];
+        }
+        offset = offset / 2;
+        workgroupBarrier();
+    }
+    sum = scratch[0];
+
     let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
-    for (var j: u32 = 0; j < params.ne0; j++) {
-        update(i_src_row + j, i_dst_row + j, scale);
+    col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        update(i_src_row + col, i_dst_row + col, scale);
+        col += wg_size;
     }
 }
 #end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl
new file mode 100644 (file)
index 0000000..64ab576
--- /dev/null
@@ -0,0 +1,344 @@
+#define(VARIANTS)
+[
+  {
+    "SHADER_NAME": "soft_max_f32",
+    "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"]
+  },
+  {
+    "SHADER_NAME": "soft_max_f32_inplace",
+    "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"]
+  },
+  {
+    "SHADER_NAME": "soft_max_f32_sink",
+    "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"]
+  },
+  {
+    "SHADER_NAME": "soft_max_f32_sink_inplace",
+    "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"]
+  },
+  {
+    "SHADER_NAME": "soft_max_f32_mask_f32",
+    "REPLS": {
+      "MASK_TYPE" : "f32",
+    },
+    "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
+  },
+  {
+    "SHADER_NAME": "soft_max_f32_mask_f32_inplace",
+    "REPLS": {
+      "MASK_TYPE" : "f32",
+    },
+    "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
+  },
+  {
+    "SHADER_NAME": "soft_max_f32_mask_f16",
+    "REPLS": {
+      "MASK_TYPE" : "f16",
+    },
+    "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
+  },
+  {
+    "SHADER_NAME": "soft_max_f32_mask_f16_inplace",
+    "REPLS": {
+      "MASK_TYPE" : "f16",
+    },
+    "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
+  },
+  {
+    "SHADER_NAME": "soft_max_f32_mask_f32_sink",
+    "REPLS": {
+      "MASK_TYPE" : "f32",
+    },
+    "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
+  },
+  {
+    "SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace",
+    "REPLS": {
+      "MASK_TYPE" : "f32",
+    },
+    "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
+  },
+  {
+    "SHADER_NAME": "soft_max_f32_mask_f16_sink",
+    "REPLS": {
+      "MASK_TYPE" : "f16",
+    },
+    "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
+  },
+  {
+    "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace",
+    "REPLS": {
+      "MASK_TYPE" : "f16",
+    },
+    "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
+  }
+]
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(BASE_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+#enddecl(BASE_BINDINGS)
+
+#decl(BASE_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<uniform> params: Params;
+#enddecl(BASE_BINDINGS_INPLACE)
+
+#decl(SINK_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#enddecl(SINK_BINDINGS)
+
+#decl(SINK_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+#enddecl(SINK_BINDINGS_INPLACE)
+
+#decl(MASK_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#enddecl(MASK_BINDINGS)
+
+#decl(MASK_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+#enddecl(MASK_BINDINGS_INPLACE)
+
+#decl(MASK_SINK_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(3)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(4)
+var<uniform> params: Params;
+#enddecl(MASK_SINK_BINDINGS)
+
+#decl(MASK_SINK_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#enddecl(MASK_SINK_BINDINGS_INPLACE)
+
+#decl(NOT_INPLACE)
+fn inter_value(i: u32) -> f32 {
+    return dst[i];
+}
+
+fn update(i: u32, val: f32) {
+    dst[i] = val;
+}
+#enddecl(NOT_INPLACE)
+
+#decl(INPLACE)
+fn inter_value(i: u32) -> f32 {
+    return src[i];
+}
+
+fn update(i: u32, val: f32) {
+    src[i] = val;
+}
+#enddecl(INPLACE)
+
+#decl(NO_MASK)
+fn mask_val(i: u32) -> f32 {
+    return 0.0;
+}
+#enddecl(NO_MASK)
+
+#decl(MASK)
+fn mask_val(i: u32) -> f32 {
+    return f32(mask[i]);
+}
+#enddecl(MASK)
+
+#decl(NO_SINK)
+fn lower_max_bound(i2: u32) -> f32 {
+    return -1e30;
+}
+
+fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
+    return val;
+}
+#enddecl(NO_SINK)
+
+#decl(SINK)
+fn lower_max_bound(i2: u32) -> f32 {
+    return sinks[params.offset_sinks + i2];
+}
+
+fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
+    return val + exp(sinks[params.offset_sinks + i2] - max_val);
+}
+#enddecl(SINK)
+
+#end(DECLS)
+
+#define(SHADER)
+enable f16;
+
+struct Params {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_sinks: u32,
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src01: u32,
+    stride_src02: u32,
+    stride_src03: u32,
+
+    stride_src11: u32,
+    stride_src12: u32,
+    stride_src13: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // shape of src0/dst
+    ne: u32,
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    // shape of src1
+    ne12: u32,
+    ne13: u32,
+
+    scale: f32,
+    max_bias: f32,
+    n_head_log2: f32,
+    m0: f32,
+    m1: f32,
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+DECLS
+
+const CACHE_SIZE: u32 = 16;
+
+override wg_size: u32;
+var<workgroup> scratch: array<f32, wg_size>;
+
+@compute @workgroup_size(wg_size)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+
+    var i = wid.x;
+    let i3 = i / (params.ne2 * params.ne1);
+    i = i % (params.ne2 * params.ne1);
+    let i2 = i / params.ne1;
+    let i1 = i % params.ne1;
+    let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
+    let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
+    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+    let elems = (params.ne0 + wg_size - 1) / wg_size;
+
+    let head = f32(i2);
+    let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);
+
+    var cache: array<f32, CACHE_SIZE>;
+
+    var max_val = lower_max_bound(i2);
+    var col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
+        max_val = max(max_val, val);
+        if (col < CACHE_SIZE) {
+            cache[col] = val;
+        }
+        col += wg_size;
+    }
+
+    scratch[lid.x] = max_val;
+    workgroupBarrier();
+    var offset = wg_size / 2;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
+        }
+        offset = offset / 2;
+        workgroupBarrier();
+    }
+    let row_max = scratch[0];
+
+    var sum = 0.0f;
+    col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col),
+                         cache[col], col < CACHE_SIZE);
+        let ex = exp(val - row_max);
+        sum += ex;
+        if (col < CACHE_SIZE) {
+            cache[col] = ex;
+        } else {
+            update(i_dst_row + col, ex);
+        }
+        col += wg_size;
+    }
+
+    scratch[lid.x] = sum;
+    workgroupBarrier();
+    offset = wg_size / 2;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            scratch[lid.x] += scratch[lid.x + offset];
+        }
+        offset = offset / 2;
+        workgroupBarrier();
+    }
+    let row_sum = add_sinks(scratch[0], i2, row_max);
+
+    let sum_recip = 1.0 / row_sum;
+    col = lid.x;
+    for  (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
+        col += wg_size;
+    }
+}
+#end(SHADER)
index 7d50b42a37d45d4508bfb261e619a74d391d0bbf..2bce1375ba3c089a458090e62e6df3cb9d7f69be 100644 (file)
@@ -3852,6 +3852,15 @@ struct ggml_tensor * ggml_soft_max_ext(
     return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
 }
 
+struct ggml_tensor * ggml_soft_max_ext_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale,
+        float                 max_bias) {
+    return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, true);
+}
+
 void ggml_soft_max_add_sinks(
         struct ggml_tensor * a,
         struct ggml_tensor * sinks) {
index 62d815cc26808e553eb7cd9878d289b74f11d30f..c1e45972e54cae9c8f8ad84e9a00f698fea1302c 100644 (file)
@@ -3752,9 +3752,10 @@ struct test_soft_max : public test_case {
     const std::array<int64_t, 2> nr23; // broadcast only dims 2 and 3
     const float scale;
     const float max_bias;
+    const bool inplace;
 
     std::string vars() override {
-        return VARS_TO_STR8(type, ne, mask, sinks, m_prec, nr23, scale, max_bias);
+        return VARS_TO_STR9(type, ne, mask, sinks, m_prec, nr23, scale, max_bias, inplace);
     }
 
     // the 1024 test with bias occasionally fails:
@@ -3770,8 +3771,9 @@ struct test_soft_max : public test_case {
             ggml_type m_prec = GGML_TYPE_F32,
             std::array<int64_t, 2> nr23 = {1, 1},
             float scale = 1.0f,
-            float max_bias = 0.0f)
-        : type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
+            float max_bias = 0.0f,
+            bool inplace = false)
+        : type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias), inplace(inplace) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
@@ -3790,7 +3792,12 @@ struct test_soft_max : public test_case {
             ggml_set_name(sinks, "sinks");
         }
 
-        ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
+        ggml_tensor * out;
+        if (inplace) {
+            out = ggml_soft_max_ext_inplace(ctx, a, mask, scale, max_bias);
+        } else {
+            out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
+        }
         ggml_soft_max_add_sinks(out, sinks);
         ggml_set_name(out, "out");
 
@@ -6562,6 +6569,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                     }
                 }
             }
+            // inplace tests
+            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f, true));
+            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, mask, sinks, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f, true));
         }
     }
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));