template <typename T>
static __global__ void k_repeat_back(
- const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
- const int64_t ne0, const int64_t ne1, const int64_t ne2) {
+ const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+ const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
- const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
- const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
- const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
+ const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
+ const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
+ const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
+ const int64_t tid2 = tid23 % ne2;
+ const int64_t tid3 = tid23 / ne2;
if (tid0 >= ne0) {
return;
}
T sum = 0;
- for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
- for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
- for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
- sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+ for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
+ for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+ for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+ for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+ sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
+ }
}
}
}
- dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+ dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
}
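// A thread computes exactly one dst element: tid0..tid3 pick the output
// coordinate and the strided loops walk every src element that broadcasts
// onto it (indices congruent to tidX modulo neX). The strides s00..s03 are
// in elements, not bytes, which is what lets a noncontiguous src view work.
// A minimal CPU reference for the same reduction (illustration only, dst
// assumed zeroed beforehand):
//
//   for (int64_t i3 = 0; i3 < ne03; i3++) {
//       for (int64_t i2 = 0; i2 < ne02; i2++) {
//           for (int64_t i1 = 0; i1 < ne01; i1++) {
//               for (int64_t i0 = 0; i0 < ne00; i0++) {
//                   dst[(i3%ne3)*ne2*ne1*ne0 + (i2%ne2)*ne1*ne0 + (i1%ne1)*ne0 + i0%ne0]
//                       += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
//               }
//           }
//       }
//   }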
template <typename T>
static void repeat_back_cuda(
- const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
- const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
+ const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+ const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
- const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
- k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+ const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
+ k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
+ (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
}
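// Launch sketch: gridDim.x tiles ne0 in WARP_SIZE chunks (hence the
// tid0 >= ne0 bounds check in the kernel), gridDim.y is ne1, and dims 2 and 3
// are fused into gridDim.z; the kernel unpacks tid23 with one div/mod pair.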
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->type == dst->type);
- GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_can_repeat(dst, src0));
cudaStream_t stream = ctx.stream();
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
- const int64_t ne02 = src0->ne[2];
- GGML_ASSERT(src0->ne[3] == 1);
+ GGML_TENSOR_UNARY_OP_LOCALS;
+
+ GGML_ASSERT(ne2*ne3 <= (1 << 15));
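// ne2*ne3 becomes gridDim.z in repeat_back_cuda, which CUDA caps at 65535,
// so (1 << 15) is a conservative bound rather than a tight one.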
- const int64_t ne0 = dst->ne[0];
- const int64_t ne1 = dst->ne[1];
- const int64_t ne2 = dst->ne[2];
- GGML_ASSERT(dst->ne[3] == 1);
+ const size_t ts = ggml_type_size(src0->type);
+ const size_t s00 = nb00 / ts;
+ const size_t s01 = nb01 / ts;
+ const size_t s02 = nb02 / ts;
+ const size_t s03 = nb03 / ts;
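// ggml strides (nb00..nb03) are in bytes; dividing by the type size gives
// the element strides the kernel indexes with. For a contiguous f32 tensor
// with ne = {8, 6, 4, 2} this yields s00..s03 = {1, 8, 48, 192}.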
switch (dst->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
- repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+ repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
} break;
default: {
GGML_ASSERT(false);
} break;
case GGML_OP_MUL: {
if (src0_needs_grads) {
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
}
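// The operand swap matters because ggml_mul broadcasts its second argument
// onto its first: grad always carries the full output shape, while src1 may
// have been broadcast in the forward pass, so grad has to come first.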
if (src1_needs_grads) {
struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
// src1.shape [n,p,qq,rr]
if (src0_needs_grads) {
- struct ggml_tensor * s1_tg =
+ GGML_ASSERT(grad->ne[2] == src1->ne[2]);
+ GGML_ASSERT(grad->ne[3] == src1->ne[3]);
+ struct ggml_tensor * tmp =
ggml_out_prod(ctx, // [n,m,qq,rr]
src1, // [n,p,qq,rr]
grad); // [m,p,qq,rr]
- const int64_t qq = s1_tg->ne[2];
- const int64_t rr = s1_tg->ne[3];
- const int64_t q1 = src0->ne[2];
- const int64_t r1 = src0->ne[3];
- const bool ne2_broadcasted = qq > q1;
- const bool ne3_broadcasted = rr > r1;
- if (ne2_broadcasted || ne3_broadcasted) {
- // sum broadcast repetitions of s1_tg into shape of src0
- s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
+ if (!ggml_are_same_shape(tmp, src0)) {
+ GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
+ GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
+ GGML_ASSERT(tmp->ne[3] == 1);
+
+ const int64_t nr2 = tmp->ne[2] / src0->ne[2];
+ const size_t nb2 = tmp->nb[2] * nr2;
+ const size_t nb3 = tmp->nb[2];
+
+ tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
+ tmp = ggml_repeat_back(ctx, tmp, src0);
}
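// The view trick above, spelled out: after the out_prod, tmp has shape
// [n, m, src0->ne[2]*nr2, 1]. Re-viewing it as [n, m, src0->ne[2], nr2] with
// nb2 = tmp->nb[2]*nr2 and nb3 = tmp->nb[2] lines up the nr2 consecutive
// slices that were broadcast from the same src0 slice along dim 3, and
// ggml_repeat_back then sums them away without copying any data.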
- ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
+ ggml_add_or_set(ctx, cgraph, isrc0, tmp);
}
if (src1_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc1,
if (src0_needs_grads) {
GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
GGML_ASSERT(ggml_is_contiguous(grad));
- ggml_add_or_set(ctx, cgraph, isrc0, grad);
+ GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
+ ggml_add_or_set(ctx, cgraph, isrc0,
+ ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
}
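// The forward op preserves the element count but may change the logical
// shape, so the gradient may need a reshape back to src0's shape before it
// is accumulated; the nelements assert guards exactly that precondition.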
} break;
case GGML_OP_RESHAPE: {
}
};
+// GGML_OP_REPEAT_BACK
+struct test_repeat_back : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne;
+ const std::array<int, 4> nr;
+ const bool v; // whether src is a noncontiguous view
+
+ std::string vars() override {
+ return VARS_TO_STR4(type, ne, nr, v);
+ }
+
+ size_t op_size(ggml_tensor * t) override {
+ return ggml_nbytes(t) * 2;
+ }
+
+ test_repeat_back(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne = {8, 6, 4, 2},
+ std::array<int, 4> nr = {2, 2, 2, 2},
+ bool v = false)
+ : type(type), ne(ne), nr(nr), v(v) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+ ggml_set_name(src, "src");
+
+ if (v) {
+ GGML_ASSERT(ne[0] % 2 == 0);
+ GGML_ASSERT(ne[1] % 2 == 0);
+ GGML_ASSERT(ne[2] % 2 == 0);
+ GGML_ASSERT(ne[3] % 2 == 0);
+ GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1);
+ GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1);
+ GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1);
+ GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1);
+
+ const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2;
+ const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2;
+ const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2;
+ const int64_t ne03 = nr[3] == 1 ? src->ne[3] : src->ne[3] / 2;
+
+ src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0);
+ }
+
+ ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_name(target, "target");
+
+ ggml_tensor * out = ggml_repeat_back(ctx, src, target);
+ ggml_set_name(out, "out");
+
+ return out;
+ }
+};
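// With v == true only half of each repeated dimension is exposed while the
// parent tensor's byte strides (src->nb[1..3]) are kept, so src becomes a
// noncontiguous view; this is the case the element strides s00..s03 in the
// CUDA kernel above exist to cover.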
+
// GGML_OP_DUP
struct test_dup : public test_case {
const ggml_type type;
return 5e-4;
}
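// grad_nmax() caps how many gradient elements test-backend-ops checks
// numerically; the larger mul_mat cases below need a higher cap than the
// default.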
+ int64_t grad_nmax() override {
+ return 20000;
+ }
+
uint64_t op_flops(ggml_tensor * t) override {
GGML_UNUSED(t);
return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1];
a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
- ggml_set_param(ctx, a);
- ggml_set_param(ctx, b);
+ if (!ggml_is_quantized(type_a)) {
+ if (bs[1] == 1 && nr[1] == 1) {
+ ggml_set_param(ctx, a);
+ }
+ ggml_set_param(ctx, b);
+ }
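// Gradients are only set up for non-quantized type_a, and a is only a
// trainable param when bs[1] == 1 && nr[1] == 1: the mul_mat backward pass
// above asserts tmp->ne[3] == 1, and rr = bs[1]*nr[1] is exactly that
// dimension.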
ggml_set_name(a, "a");
ggml_set_name(b, "b");
} else {
a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
- ggml_set_param(ctx, a);
- ggml_set_param(ctx, b);
+ if (!ggml_is_quantized(type_a)) {
+ if (bs[1] == 1 && nr[1] == 1) {
+ ggml_set_param(ctx, a);
+ }
+ ggml_set_param(ctx, b);
+ }
ggml_set_name(a, "a");
ggml_set_name(b, "b");
}
test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));
}
+ for (bool view : {false, true}) {
+ test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view));
+ test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
+ test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
+ test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
+ test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
+ test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
+ test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
+ }
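// Each nr pattern isolates one broadcast dimension; the I32/I16 cases cover
// non-float types, and view == true reruns every case on a noncontiguous
// source.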
+
test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
test_cases.emplace_back(new test_dup(GGML_TYPE_I32));
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
// test cases without permutation
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
-
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {2, 1}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
- test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {2, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 2}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {2, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 2}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 2}));
+
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2}));
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2}));
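// The new matrix covers broadcast without batching ({1, 1} with {2, 1} or
// {1, 2}) and batching with and without broadcast ({3, 1}, {3, 2}); the
// batch counts are smaller than before (3 instead of 10), presumably to keep
// the test runtime down.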
// test cases with permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));