]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : extend bin bcast for permuted src1 (#19484)
authorGeorgi Gerganov <redacted>
Wed, 11 Feb 2026 05:52:00 +0000 (07:52 +0200)
committerGitHub <redacted>
Wed, 11 Feb 2026 05:52:00 +0000 (07:52 +0200)
* tests : extend bin bcast for permuted src1

* cont : extend bin support

* cont : s0 is always 1

* tests : simplify

ggml/src/ggml-cpu/binary-ops.cpp
ggml/src/ggml-cuda/binbcast.cu
tests/test-backend-ops.cpp

index 14f5b43ae0eb1e8547c23c223c2bcca535f965b3..75e38290015a6fd6137fa09142585051cf27179b 100644 (file)
@@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
     GGML_ASSERT(nb00 == sizeof(src0_t));
 
     const auto [ir0, ir1] = get_thread_range(params, src0);
-    const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
-
-    if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
-        GGML_ASSERT(ggml_are_same_shape(src0, src1));
-    }
+    const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1);
 
 #ifdef GGML_USE_ACCELERATE
     vDSP_fn_t vDSP_op = nullptr;
@@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
         const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
         const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
-        if (is_src1_contiguous) {
+        if (is_src1_contiguous_rows) {
             // src1 is broadcastable across src0 and dst in i1, i2, i3
             const int64_t nr0 = ne00 / ne10;
 
index 0e6d777b1e64a8b46cdaaa3e65f9c8c8d9028ced..7339fe0c070d3855430593f8a913bc7088b6f778 100644 (file)
@@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t *         src0,
                                    const uint3            ne11,
                                    const uint3            ne12,
                                    const uint3            ne13,
-                                   /*int s0, */ const int s1,
+                                 /*const int              s0,*/
+                                   const int              s1,
                                    const int              s2,
                                    const int              s3,
-                                   /*int s00,*/ const int s01,
+                                   const int              s00,
+                                   const int              s01,
                                    const int              s02,
                                    const int              s03,
-                                   /*int s10,*/ const int s11,
+                                   const int              s10,
+                                   const int              s11,
                                    const int              s12,
                                    const int              s13,
                                    src1_ptrs... src1s) {
@@ -72,11 +75,11 @@ static __global__ void k_bin_bcast(const src0_t *         src0,
     for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
         const uint32_t i10 = fastmodulo(i0, ne10);
 
-        float result = src0_row ? (float) src0_row[i0] : 0.0f;
+        float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
         if constexpr (sizeof...(src1_ptrs) > 0) {
-            result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+            result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
         } else {
-            result = bin_op(result, (float)src1[i_src1 + i10]);
+            result = bin_op(result, (float)src1[i_src1 + i10*s10]);
         }
 
         dst_row[i0] = (dst_t) result;
@@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t *         src0,
                                            const uint3            ne11,
                                            const uint3            ne12,
                                            const uint3            ne13,
-                                           /*int s0, */ const int s1,
+                                         /*const int              s0,*/
+                                           const int              s1,
                                            const int              s2,
                                            const int              s3,
-                                           /*int s00,*/ const int s01,
+                                           const int              s00,
+                                           const int              s01,
                                            const int              s02,
                                            const int              s03,
-                                           /*int s10,*/ const int s11,
+                                           const int              s10,
+                                           const int              s11,
                                            const int              s12,
                                            const int              s13,
                                            src1_ptrs... src1s) {
@@ -135,11 +141,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t *         src0,
 
     const int i10 = fastmodulo(i0, ne10);
 
-    float result = src0_row ? (float) src0_row[i0] : 0.0f;
+    float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
     if constexpr (sizeof...(src1_ptrs) > 0) {
-        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
     } else {
-        result = bin_op(result, (float)src1[i_src1 + i10]);
+        result = bin_op(result, (float)src1[i_src1 + i10*s10]);
     }
 
     dst_row[i0] = (dst_t) result;
@@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
         cnb[3] *= cne[3];
     };
 
-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
         for (int i = 0; i < 4; i++) {
             if (nr[i] != 1) {
                 break;
@@ -221,7 +227,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
         size_t nb12 = cnb1[2];
         size_t nb13 = cnb1[3];
 
-        size_t s0 = nb0 / sizeof(dst_t);
+      //size_t s0 = nb0 / sizeof(dst_t);
         size_t s1 = nb1 / sizeof(dst_t);
         size_t s2 = nb2 / sizeof(dst_t);
         size_t s3 = nb3 / sizeof(dst_t);
@@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
         GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
         GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
 
-        GGML_ASSERT(s0 == 1);
-        GGML_ASSERT(s00 == 1);
-        GGML_ASSERT(s10 == 1);
-
         const int block_size = 128;
 
         int64_t hne0 = std::max(ne0 / 2LL, 1LL);
@@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
                 k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
                     src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
                     ne12, ne13,
-                    /* s0, */ s1, s2, s3,
-                    /* s00,*/ s01, s02, s03,
-                    /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
+                  /*s0,*/ s1,  s2,  s3,
+                    s00, s01, s02, s03,
+                    s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
             } else {
                 k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
                     <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
                                                            ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
-                                                           /* s0, */ s1, s2, s3,
-                                                           /* s00,*/ s01, s02, s03,
-                                                           /* s10,*/ s11, s12, s13);
+                                                         /*s0,*/ s1,  s2,  s3,
+                                                           s00, s01, s02, s03,
+                                                           s10, s11, s12, s13);
             }
         } else {
             const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
             if constexpr (sizeof...(I) > 0) {
                 k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
                     src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
-                    /* s0, */ s1, s2, s3,
-                    /* s00,*/ s01, s02, s03,
-                    /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
+                  /*s0,*/ s1, s2,  s3,
+                    s00 ,s01, s02, s03,
+                    s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
             } else {
                 k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
                     src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
-                    /* s0, */ s1, s2, s3,
-                    /* s00,*/ s01, s02, s03,
-                    /* s10,*/ s11, s12, s13);
+                  /*s0,*/ s1,  s2,  s3,
+                    s00, s01, s02, s03,
+                    s10, s11, s12, s13);
             }
         }
     }
index 1f40d2fc38688c0ad7f16952aa6dd55cce9e92e1..5d5e44a0c79f16cdc482fcf127a83c98ce52480a 100644 (file)
@@ -2964,11 +2964,12 @@ struct test_bin_bcast : public test_case {
     const std::array<int64_t, 4> ne;
     const std::array<int, 4> nr;
     int nf; // number of fused ops, nf == 1 -> single op (no fusion)
+    bool perm1; // permute src1?
 
     bool run_whole_graph() override { return nf > 1; }
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne, nr, nf);
+        return VARS_TO_STR5(type, ne, nr, nf, perm1);
     }
 
     size_t op_size(ggml_tensor * t) override {
@@ -2978,8 +2979,9 @@ struct test_bin_bcast : public test_case {
     test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 1, 1},
             std::array<int, 4> nr = {1, 2, 1, 1},
-            int nf = 1)
-        : op(op), type(type), ne(ne), nr(nr), nf(nf) {}
+            int nf = 1,
+            bool perm1 = false)
+        : op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         GGML_ASSERT(nf <= 16);
@@ -2989,12 +2991,19 @@ struct test_bin_bcast : public test_case {
 
         ggml_tensor * b[16];
         for (int i = 0; i < nf; ++i) {
-            b[i] = ggml_new_tensor(ctx, type, 4, ne.data());
+            if (perm1) {
+                const int p[4] = { 1, 2, 0, 3 }; // hardcoded for now
+
+                b[i] = ggml_new_tensor_4d(ctx, type, ne[p[0]], ne[p[1]], ne[p[2]], ne[p[3]]);
+                b[i] = ggml_permute(ctx, b[i], p[0], p[1], p[2], p[3]);
+            } else {
+                b[i] = ggml_new_tensor(ctx, type, 4, ne.data());
+            }
             ggml_set_name(b[i], (std::string("b") + std::to_string(i)).c_str());
         }
 
         // The backward pass supports broadcasting only for GGML_ADD:
-        const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1;
+        const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1 && !perm1;
         if (grad_supported) {
             ggml_set_param(a);
             ggml_set_param(b[0]);
@@ -7477,25 +7486,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
-    auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
+    auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr, bool perm1 = false) {
         for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
-            test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
+            test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1));
         }
     };
     for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
-        add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1});
-        add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1});
-        add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1});
-        add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1});
-        add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1});
-        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1});
-        add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1});
-        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1});
-        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1});
-        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2});
-        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2});
-        add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
-        add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
+        for (bool perm1 : {false, true}) {
+            add_test_bin_bcast(type, {1,  1,   8,   1}, {1,  1, 1, 1}, perm1);
+            add_test_bin_bcast(type, {1,  1,   1,   1}, {32, 1, 1, 1}, perm1);
+            add_test_bin_bcast(type, {1,  1, 320, 320}, {1,  1, 1, 1}, perm1);
+            add_test_bin_bcast(type, {10, 5,   1,   1}, {1,  1, 1, 1}, perm1);
+            add_test_bin_bcast(type, {10, 5,   4,   1}, {1,  1, 1, 1}, perm1);
+            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  1, 1, 1}, perm1);
+            add_test_bin_bcast(type, {10, 5,   4,   3}, {2,  1, 1, 1}, perm1);
+            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  2, 1, 1}, perm1);
+            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  1, 2, 1}, perm1);
+            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  1, 1, 2}, perm1);
+            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  1, 2, 2}, perm1);
+            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  2, 2, 2}, perm1);
+            add_test_bin_bcast(type, {10, 5,   4,   3}, {2,  2, 2, 2}, perm1);
+        }
 
         // test case for k_bin_bcast_unravel in CUDA backend
         add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1});