From: Jiahao Li Date: Tue, 11 Jul 2023 18:06:05 +0000 (+0800) Subject: ggml : broadcast ggml_add() for F32 (#359) X-Git-Tag: upstream/0.0.1642~1334 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=649a922dc727f246c1280e49f2dbca6ed26d66c3;p=pkg%2Fggml%2Fsources%2Fggml ggml : broadcast ggml_add() for F32 (#359) * Support broadcast add for fp32 * Use single kernel for broadcast add --- diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index f9b55173..17c8bb76 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -239,13 +239,13 @@ struct ggml_tensor_extra_gpu { cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs }; -static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) { +static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { const int i = blockDim.x*blockIdx.x + threadIdx.x; - if (i >= k) { + if (i >= kx) { return; } - dst[i] = x[i] + y[i]; + dst[i] = x[i] + y[i%ky]; } static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) { @@ -1718,9 +1718,9 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale dst[i] = scale * x[i]; } -static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; - add_f32<<>>(x, y, dst, k); +static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { + const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; + add_f32<<>>(x, y, dst, kx, ky); } static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) { @@ -2274,14 +2274,16 @@ inline void ggml_cuda_op_add( GGML_ASSERT(src1_ddf_i != nullptr); GGML_ASSERT(dst_ddf_i != nullptr); - const int64_t ne0 = src0->ne[0]; + const int64_t ne00 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; + const int64_t ne10 = src1->ne[0]; + // compute if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main); + add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main); } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { - add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main); + add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main); } else { GGML_ASSERT(false); } diff --git a/src/ggml.c b/src/ggml.c index 6419cc8b..d84afbb6 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -5035,11 +5035,15 @@ struct ggml_tensor * ggml_add_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - GGML_ASSERT(ggml_are_same_shape(a, b)); + // TODO: support less-strict constraint + // GGML_ASSERT(ggml_can_repeat(b, a)); + GGML_ASSERT(ggml_can_repeat_rows(b, a)); bool is_node = false; - if (a->grad || b->grad) { + if (!inplace && (a->grad || b->grad)) { + // TODO: support backward pass for broadcasting + GGML_ASSERT(ggml_are_same_shape(a, b)); is_node = true; } @@ -8297,7 +8301,7 @@ static void ggml_compute_forward_add_f32( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -8322,23 +8326,23 @@ static void ggml_compute_forward_add_f32( if (nb10 == sizeof(float)) { for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); #ifdef GGML_USE_ACCELERATE - vDSP_vadd( - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, - ne0); + vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00); #else - ggml_vec_add_f32(ne0, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); + ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr); #endif // } // } @@ -8346,15 +8350,20 @@ static void ggml_compute_forward_add_f32( } else { // src1 is not contiguous for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); for (int i0 = 0; i0 < ne0; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10); dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; }