// Diff hunk for the k_repeat_back kernel; the notes below describe the added ('+') side.
// Each thread accumulates one element of dst: along every dimension k it visits the src
// indices tid_k, tid_k + ne_k, tid_k + 2*ne_k, ... that are below the src extent ne0k, and
// sums the values found there.
// src is addressed through the per-dimension strides s00..s03, which are applied directly to
// a T pointer (so they count elements, not bytes); dst is written as a dense ne0*ne1*ne2*ne3 array.
// Only tid0 is bounds-checked, so the launch must not produce y indices >= ne1 or z indices >= ne2*ne3.
template <typename T>
static __global__ void k_repeat_back(
- const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
- const int64_t ne0, const int64_t ne1, const int64_t ne2) {
+ const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+ const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
- const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
- const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
- const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
// blockIdx is widened to int64_t before the multiply so the global index is computed in 64 bits.
+ const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
+ const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
+ const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
// The z index carries dst dimensions 2 and 3 packed together; split it back into the two indices.
+ const int64_t tid2 = tid23 % ne2;
+ const int64_t tid3 = tid23 / ne2;
// Threads whose x index lies past the end of dst dimension 0 have nothing to write.
if (tid0 >= ne0) {
return;
}
T sum = 0;
- for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
- for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
- for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
- sum += src[i2*ne01*ne00 + i1*ne00 + i0];
// Stepping each index by the dst extent walks over every src position that maps onto this dst element.
+ for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
+ for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+ for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+ for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+ sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
+ }
}
}
}
- dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
// dst is indexed as a dense 4D array, dimension 0 fastest.
+ dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
}
template<float (*bin_op)(const float, const float)>
// Diff hunk for the host-side launcher of k_repeat_back; the notes below describe the added ('+') side.
// Launches the kernel on `stream` with blocks of WARP_SIZE x 1 x 1 threads.
// ne00..ne03 are the src extents, ne0..ne3 the dst extents; s00..s03 are forwarded unchanged
// and used by the kernel as strides on a T pointer.
template <typename T>
static void repeat_back_cuda(
- const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
- const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
+ const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+ const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
- const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
- k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
// Grid: x covers ne0 rounded up to whole blocks, y has one block per dim-1 index, and z has
// ne2*ne3 blocks (dims 2 and 3 packed together; the kernel splits them again).
// NOTE(review): ne1 and ne2*ne3 are used as grid extents without a range check here — confirm the caller bounds them.
+ const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
+ k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
+ (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
}
template<class op>
// Diff hunk from the body of the host function that dispatches repeat_back
// (the function header and the end of the switch are outside this hunk).
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->type == dst->type);
// The added side no longer requires src0 to be contiguous; dst still has to be.
- GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_can_repeat(dst, src0));
cudaStream_t stream = ctx.stream();
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
- const int64_t ne02 = src0->ne[2];
- GGML_ASSERT(src0->ne[3] == 1);
// NOTE(review): presumably this macro declares the ne0x/nb0x (src0) and nex (dst) locals
// that replace the removed declarations — confirm against its definition.
+ GGML_TENSOR_UNARY_OP_LOCALS;
+
// ne2*ne3 is handed to repeat_back_cuda, which uses it as the z extent of the launch grid,
// so it is bounded here (CUDA limits a grid's z dimension to 65535 blocks).
+ GGML_ASSERT(ne2*ne3 <= (1 << 15));
- const int64_t ne0 = dst->ne[0];
- const int64_t ne1 = dst->ne[1];
- const int64_t ne2 = dst->ne[2];
- GGML_ASSERT(dst->ne[3] == 1);
// Divide the src0 strides by the element size so they can be applied to a typed pointer
// (assumes the nb0x values are byte strides — TODO confirm).
+ const size_t ts = ggml_type_size(src0->type);
+ const size_t s00 = nb00 / ts;
+ const size_t s01 = nb01 / ts;
+ const size_t s02 = nb02 / ts;
+ const size_t s03 = nb03 / ts;
switch (dst->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
- repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+ repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
} break;
// Only F32 is dispatched in the visible code; every other type reaches GGML_ASSERT(false).
default: {
GGML_ASSERT(false);
} break;
case GGML_OP_MUL: {
if (src0_needs_grads) {
- ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
}
if (src1_needs_grads) {
struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
// src1.shape [n,p,qq,rr]
// Diff hunk (enclosing case label not shown): gradient contribution for src0, built from the
// outer product of src1 and grad and, if needed, reduced to the shape of src0.
if (src0_needs_grads) {
- struct ggml_tensor * s1_tg =
// The added side requires grad and src1 to agree in dimensions 2 and 3 before taking the product.
+ GGML_ASSERT(grad->ne[2] == src1->ne[2]);
+ GGML_ASSERT(grad->ne[3] == src1->ne[3]);
+ struct ggml_tensor * tmp =
ggml_out_prod(ctx, // [n,m,qq,rr]
src1, // [n,p,qq,rr]
grad); // [m,p,qq,rr]
- const int64_t qq = s1_tg->ne[2];
- const int64_t rr = s1_tg->ne[3];
- const int64_t q1 = src0->ne[2];
- const int64_t r1 = src0->ne[3];
- const bool ne2_broadcasted = qq > q1;
- const bool ne3_broadcasted = rr > r1;
- if (ne2_broadcasted || ne3_broadcasted) {
- // sum broadcast repetitions of s1_tg into shape of src0
- s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
// Added side: when the product's shape differs from src0, only dimension 2 may differ
// (dims 0 and 1 must match and the product's dim 3 must be 1).
+ if (!ggml_are_same_shape(tmp, src0)) {
+ GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
+ GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
+ GGML_ASSERT(tmp->ne[3] == 1);
+
// nr2: number of dim-2 slices of tmp per dim-2 slice of src0 (integer division).
+ const int64_t nr2 = tmp->ne[2] / src0->ne[2];
+ const size_t nb2 = tmp->nb[2] * nr2;
+ const size_t nb3 = tmp->nb[2];
+
// Re-view tmp so that dim 2 has src0's extent and dim 3 has nr2 entries: with these strides,
// view index (i2, i3) addresses tmp's dim-2 slice i2*nr2 + i3. repeat_back then reduces the
// view to the shape of src0.
+ tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
+ tmp = ggml_repeat_back(ctx, tmp, src0);
}
- ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
+ ggml_add_or_set(ctx, cgraph, isrc0, tmp);
}
if (src1_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc1,
// Diff hunk (enclosing case label not shown): forward grad to src0 as its gradient contribution.
if (src0_needs_grads) {
GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
GGML_ASSERT(ggml_is_contiguous(grad));
- ggml_add_or_set(ctx, cgraph, isrc0, grad);
// Added side: element counts must match, and grad is reshaped to src0's shape
// whenever tensor and src0 differ in shape; otherwise grad is used as is.
+ GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
+ ggml_add_or_set(ctx, cgraph, isrc0,
+ ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
}
} break;
case GGML_OP_RESHAPE: {