const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
- /*int s0, */ const int s1,
+ /*const int s0,*/
+ const int s1,
const int s2,
const int s3,
- /*int s00,*/ const int s01,
+ const int s00,
+ const int s01,
const int s02,
const int s03,
- /*int s10,*/ const int s11,
+ const int s10,
+ const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
const uint32_t i10 = fastmodulo(i0, ne10);
- float result = src0_row ? (float) src0_row[i0] : 0.0f;
+ float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
- result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+ result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
} else {
- result = bin_op(result, (float)src1[i_src1 + i10]);
+ result = bin_op(result, (float)src1[i_src1 + i10*s10]);
}
dst_row[i0] = (dst_t) result;
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
- /*int s0, */ const int s1,
+ /*const int s0,*/
+ const int s1,
const int s2,
const int s3,
- /*int s00,*/ const int s01,
+ const int s00,
+ const int s01,
const int s02,
const int s03,
- /*int s10,*/ const int s11,
+ const int s10,
+ const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
const int i10 = fastmodulo(i0, ne10);
- float result = src0_row ? (float) src0_row[i0] : 0.0f;
+ float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
- result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+ result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
} else {
- result = bin_op(result, (float)src1[i_src1 + i10]);
+ result = bin_op(result, (float)src1[i_src1 + i10*s10]);
}
dst_row[i0] = (dst_t) result;
cnb[3] *= cne[3];
};
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
- size_t s0 = nb0 / sizeof(dst_t);
+ //size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
- GGML_ASSERT(s0 == 1);
- GGML_ASSERT(s00 == 1);
- GGML_ASSERT(s10 == 1);
-
const int block_size = 128;
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00,*/ s01, s02, s03,
- /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
+ /*s0,*/ s1, s2, s3,
+ s00, s01, s02, s03,
+ s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00,*/ s01, s02, s03,
- /* s10,*/ s11, s12, s13);
+ /*s0,*/ s1, s2, s3,
+ s00, s01, s02, s03,
+ s10, s11, s12, s13);
}
} else {
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
if constexpr (sizeof...(I) > 0) {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00,*/ s01, s02, s03,
- /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
+ /*s0,*/ s1, s2, s3,
+ s00, s01, s02, s03,
+ s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00,*/ s01, s02, s03,
- /* s10,*/ s11, s12, s13);
+ /*s0,*/ s1, s2, s3,
+ s00, s01, s02, s03,
+ s10, s11, s12, s13);
}
}
}
const std::array<int64_t, 4> ne;
const std::array<int, 4> nr;
int nf; // number of fused ops, nf == 1 -> single op (no fusion)
+ bool perm1; // permute src1?
bool run_whole_graph() override { return nf > 1; }
std::string vars() override {
- return VARS_TO_STR4(type, ne, nr, nf);
+ return VARS_TO_STR5(type, ne, nr, nf, perm1);
}
size_t op_size(ggml_tensor * t) override {
test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 1, 1},
std::array<int, 4> nr = {1, 2, 1, 1},
- int nf = 1)
- : op(op), type(type), ne(ne), nr(nr), nf(nf) {}
+ int nf = 1,
+ bool perm1 = false)
+ : op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
GGML_ASSERT(nf <= 16);
ggml_tensor * b[16];
for (int i = 0; i < nf; ++i) {
- b[i] = ggml_new_tensor(ctx, type, 4, ne.data());
+ if (perm1) {
+ const int p[4] = { 1, 2, 0, 3 }; // hardcoded for now
+
+ b[i] = ggml_new_tensor_4d(ctx, type, ne[p[0]], ne[p[1]], ne[p[2]], ne[p[3]]);
+ b[i] = ggml_permute(ctx, b[i], p[0], p[1], p[2], p[3]);
+ } else {
+ b[i] = ggml_new_tensor(ctx, type, 4, ne.data());
+ }
ggml_set_name(b[i], (std::string("b") + std::to_string(i)).c_str());
}
// The backward pass supports broadcasting only for GGML_ADD:
- const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1;
+ const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1 && !perm1;
if (grad_supported) {
ggml_set_param(a);
ggml_set_param(b[0]);
}
}
- auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
+ auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr, bool perm1 = false) {
for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
- test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
+ test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1));
}
};
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
- add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1});
- add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1});
- add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1});
- add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1});
- add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1});
- add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1});
- add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1});
- add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1});
- add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1});
- add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2});
- add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2});
- add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
- add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
+ for (bool perm1 : {false, true}) {
+ add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1}, perm1);
+ add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1}, perm1);
+ add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1}, perm1);
+ add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1}, perm1);
+ add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1}, perm1);
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1}, perm1);
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1}, perm1);
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1}, perm1);
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1}, perm1);
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2}, perm1);
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2}, perm1);
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2}, perm1);
+ add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2}, perm1);
+ }
// test case for k_bin_bcast_unravel in CUDA backend
add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1});