From: Georgi Gerganov Date: Wed, 11 Feb 2026 05:52:00 +0000 (+0200) Subject: ggml : extend bin bcast for permuted src1 (llama/19484) X-Git-Tag: upstream/1.8.3+155~25 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=350435805607207934b4e79554012649e8d0cc53;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp ggml : extend bin bcast for permuted src1 (llama/19484) * tests : extend bin bcast for permuted src1 * cont : extend bin support * cont : s0 is always 1 * tests : simplify --- diff --git a/ggml/src/ggml-cpu/binary-ops.cpp b/ggml/src/ggml-cpu/binary-ops.cpp index 14f5b43a..75e38290 100644 --- a/ggml/src/ggml-cpu/binary-ops.cpp +++ b/ggml/src/ggml-cpu/binary-ops.cpp @@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds GGML_ASSERT(nb00 == sizeof(src0_t)); const auto [ir0, ir1] = get_thread_range(params, src0); - const bool is_src1_contiguous = (nb10 == sizeof(src1_t)); - - if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous - GGML_ASSERT(ggml_are_same_shape(src0, src1)); - } + const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1); #ifdef GGML_USE_ACCELERATE vDSP_fn_t vDSP_op = nullptr; @@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - if (is_src1_contiguous) { + if (is_src1_contiguous_rows) { // src1 is broadcastable across src0 and dst in i1, i2, i3 const int64_t nr0 = ne00 / ne10; diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 0e6d777b..7339fe0c 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t * src0, const uint3 ne11, const uint3 ne12, const uint3 ne13, - /*int s0, */ const int s1, + /*const int s0,*/ + const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, + const int s00, + const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, + const int s10, + const int s11, const int s12, const int s13, src1_ptrs... src1s) { @@ -72,11 +75,11 @@ static __global__ void k_bin_bcast(const src0_t * src0, for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { const uint32_t i10 = fastmodulo(i0, ne10); - float result = src0_row ? (float) src0_row[i0] : 0.0f; + float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10]); + result = bin_op(result, (float)src1[i_src1 + i10*s10]); } dst_row[i0] = (dst_t) result; @@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const uint3 ne11, const uint3 ne12, const uint3 ne13, - /*int s0, */ const int s1, + /*const int s0,*/ + const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, + const int s00, + const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, + const int s10, + const int s11, const int s12, const int s13, src1_ptrs... src1s) { @@ -135,11 +141,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const int i10 = fastmodulo(i0, ne10); - float result = src0_row ? (float) src0_row[i0] : 0.0f; + float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10]); + result = bin_op(result, (float)src1[i_src1 + i10*s10]); } dst_row[i0] = (dst_t) result; @@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * cnb[3] *= cne[3]; }; - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { for (int i = 0; i < 4; i++) { if (nr[i] != 1) { break; @@ -221,7 +227,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * size_t nb12 = cnb1[2]; size_t nb13 = cnb1[3]; - size_t s0 = nb0 / sizeof(dst_t); + //size_t s0 = nb0 / sizeof(dst_t); size_t s1 = nb1 / sizeof(dst_t); size_t s2 = nb2 / sizeof(dst_t); size_t s3 = nb3 / sizeof(dst_t); @@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * GGML_ASSERT(nb12 % sizeof(src1_t) == 0); GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s00 == 1); - GGML_ASSERT(s10 == 1); - const int block_size = 128; int64_t hne0 = std::max(ne0 / 2LL, 1LL); @@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * k_bin_bcast_unravel<<>>( src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } else { k_bin_bcast_unravel <<>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13); + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13); } } else { const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); if constexpr (sizeof...(I) > 0) { k_bin_bcast<<>>( src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + /*s0,*/ s1, s2, s3, + s00 ,s01, s02, s03, + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } else { k_bin_bcast<<>>( src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13); + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13); } } }