const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
- int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
- int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
- // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
- int offset = dst->op_params[3] / 4; // offset in bytes
+ // op_params carry the view's byte strides and byte offset; convert them to
+ // element counts of src0's type so the shader can index by element.
+ int nb1 = dst->op_params[0] / src0_type_size; // row stride, in elements
+ int nb2 = dst->op_params[1] / src0_type_size; // plane stride, in elements
+ int nb3 = dst->op_params[2] / src0_type_size; // batch stride, in elements
+ int offset = dst->op_params[3] / src0_type_size; // view offset, in elements (op_params[3] is bytes)
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
    (uint32_t)ggml_nelements(src0),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
+ // NOTE(review): the src0 and dst rows both use the op_params view strides
+ // (nb1/nb2/nb3) rather than the tensors' own nb[] — confirm the shader expects this.
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
    (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
    0,
    0.0f, 0.0f, offset,
});
const ggml_type type;
const std::array<int64_t, 4> ne_a;
const std::array<int64_t, 4> ne_b;
+ // Dimension of b to pad and view at a non-zero offset (1, 2 or 3);
+ // any other value (default -1) builds a plain contiguous b.
+ const int64_t stride_dim;
std::string vars() override {
- return VARS_TO_STR3(type, ne_a, ne_b);
+ return VARS_TO_STR4(type, ne_a, ne_b, stride_dim);
}
test_acc(ggml_type type = GGML_TYPE_F32,
- std::array<int64_t, 4> ne_a = {256, 17, 1, 1},
- std::array<int64_t, 4> ne_b = {256, 16, 1, 1})
- : type(type), ne_a(ne_a), ne_b(ne_b) {}
+ std::array<int64_t, 4> ne_a = {256, 17, 2, 3},
+ std::array<int64_t, 4> ne_b = {256, 16, 2, 3},
+ int64_t stride_dim = -1)
+ : type(type), ne_a(ne_a), ne_b(ne_b), stride_dim(stride_dim) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
    ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
    ggml_set_param(a);
    ggml_set_name(a, "a");
- ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
- ggml_set_param(b);
+ ggml_tensor * b;
+ if (stride_dim == 1 || stride_dim == 2 || stride_dim == 3) {
+ // Create a larger tensor and take a view at a non-zero offset.
+ // This tests that the backend correctly handles b's data offset
+ std::array<int64_t, 4> ne_b_pad = {ne_b[0], ne_b[1], ne_b[2], ne_b[3]};
+ ne_b_pad[stride_dim] += 1;
+ ggml_tensor * b_pad = ggml_new_tensor(ctx, type, 4, ne_b_pad.data());
+ // NOTE(review): the param is set on the backing tensor b_pad, not on the
+ // view b — presumably because views cannot be params; confirm.
+ ggml_set_param(b_pad);
+ ggml_set_name(b_pad, "b_pad");
+ // View that skips the first row, so b has a non-zero byte offset
+ // Args: dims = ne_b, strides inherited from the padded tensor (so they do
+ // not match a contiguous ne_b tensor's natural strides), offset = one row.
+ b = ggml_view_4d(ctx, b_pad,
+ ne_b[0], ne_b[1], ne_b[2], ne_b[3],
+ b_pad->nb[1], b_pad->nb[2], b_pad->nb[3],
+ b_pad->nb[1]);
+ } else {
+ b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+ ggml_set_param(b);
+ }
ggml_set_name(b, "b");
- ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]);
+ // When ne_b[0] < ne_a[0], a->nb[1] != b->nb[1], so the stride
+ // parameters to ggml_acc don't match b's natural stride.
+ ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(out, "out");
return out;
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1}));
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1}));
- test_cases.emplace_back(new test_acc());
+ // acc: contiguous b (1 and multi batch), narrower b, and strided/offset
+ // views of b along dims 1-3 (last ctor arg selects the padded dim)
+ test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 1, 1}, {256, 16, 1, 1}, -1));
+ test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, -1));
+ test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, -1));
+ test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, 1));
+ test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2));
+ test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3));
test_cases.emplace_back(new test_pad());
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {33, 17, 2, 1}, 4, 3, true)); // circular
test_cases.emplace_back(new test_pad_ext());
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate
+ // acc
+ // NOTE(review): these six registrations duplicate the acc cases added in the
+ // earlier hunk — presumably one list is eval and this one is perf, but that
+ // cannot be confirmed from this view; verify both lists are intended.
+ test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 1, 1}, {256, 16, 1, 1}, -1));
+ test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, -1));
+ test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, -1));
+ test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, 1));
+ test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2));
+ test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3));
+
return test_cases;
}