]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
fix vulkan ggml_acc only works in 3d but not 4d (llama/19426)
authorymcki <redacted>
Fri, 13 Feb 2026 12:31:37 +0000 (20:31 +0800)
committerGeorgi Gerganov <redacted>
Sun, 15 Feb 2026 19:44:37 +0000 (21:44 +0200)
* fix vulkan ggml_acc only works in 3d but not 4d

* removed clamp in test_acc_block

* use the correct stride and its test case

* cuda : fix "supports op" condition

* change src0 to src1 in ggml_vk_acc. Update acc.comp with jeffbolznv\'s suggestion except to keep the boundary check

* version without boundary check

* revert back to boundary check version

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/acc.comp

index 7dc688483ad19fd941f66ecbd48e01a82237e63a..85ce96958fa0c050e4a5349d7366bbf7bbf250f4 100644 (file)
@@ -4822,8 +4822,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
-        case GGML_OP_ACC:
             return true;
+        case GGML_OP_ACC:
+            // TODO: extend support like so:
+            //return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]);
+            return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
         case GGML_OP_SUM:
             return ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_TOP_K:
index 72097ffd0ffc73ae6e09ff7783ed20904827e0ae..e5dcd3cbda276cb02bc729f524fb04c5642da59d 100644 (file)
@@ -9801,16 +9801,16 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
-    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
-    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
-    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
-    int offset = dst->op_params[3] / 4; // offset in bytes
+    int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32
+    int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32
+    int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32
+    int offset = dst->op_params[3] / src0_type_size; // offset in bytes
 
     ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
         (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] /  dst_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
         0,
         0.0f, 0.0f, offset,
     });
index 5084a70ed49f76ca8137d6c0fd8b068ae5f8081a..3d61168b56f09f97a76042f62370279dfe73a204 100644 (file)
@@ -13,17 +13,18 @@ void main() {
 
     const uint offset = p.param3;
     const uint src1_i = idx - offset;
-    const uint oz = src1_i / p.nb02;
-    const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
-    const uint ox = src1_i % p.nb01;
+    const uint i3 = src1_i / p.nb03;
+    const uint rem2 = src1_i - i3 * p.nb03;
+    const uint i2 = rem2 / p.nb02;
+    const uint rem1 = rem2 - i2 * p.nb02;
+    const uint i1 = rem1 / p.nb01;
+    const uint i0 = rem1 % p.nb01;
 
     uint i00, i01, i02, i03;
-    get_indices(idx, i00, i01, i02, i03);
 
-    if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
-        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
+    if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
+        data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
     } else {
-        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
+        data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]));
     }
 }
-