]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : broadcast ggml_add() for F32 (#359)
authorJiahao Li <redacted>
Tue, 11 Jul 2023 18:06:05 +0000 (02:06 +0800)
committerGitHub <redacted>
Tue, 11 Jul 2023 18:06:05 +0000 (21:06 +0300)
* Support broadcast add for fp32

* Use single kernel for broadcast add

src/ggml-cuda.cu
src/ggml.c

index f9b55173b9a2cbe812086c0f413eec37d9e3f421..17c8bb76137e68d9b4611126f61b585992bff775 100644 (file)
@@ -239,13 +239,13 @@ struct ggml_tensor_extra_gpu {
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
 };
 
-static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
+static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i >= k) {
+    if (i >= kx) {
         return;
     }
-    dst[i] = x[i] + y[i];
+    dst[i] = x[i] + y[i%ky];
 }
 
 static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
@@ -1718,9 +1718,9 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
     dst[i] = scale * x[i];
 }
 
-static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
-    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 }
 
 static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
@@ -2274,14 +2274,16 @@ inline void ggml_cuda_op_add(
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i != nullptr);
 
-    const int64_t ne0 = src0->ne[0];
+    const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
+    const int64_t ne10 = src1->ne[0];
+
     // compute
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
     } else {
         GGML_ASSERT(false);
     }
index 6419cc8b604ea42c12fe4a1d3ccdc422838f07ec..d84afbb6d423160fa45dfe0a74db2108e48b58dc 100644 (file)
@@ -5035,11 +5035,15 @@ struct ggml_tensor * ggml_add_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
+    // TODO: support less-strict constraint
+    //       GGML_ASSERT(ggml_can_repeat(b, a));
+    GGML_ASSERT(ggml_can_repeat_rows(b, a));
 
     bool is_node = false;
 
-    if (a->grad || b->grad) {
+    if (!inplace && (a->grad || b->grad)) {
+        // TODO: support backward pass for broadcasting
+        GGML_ASSERT(ggml_are_same_shape(a, b));
         is_node = true;
     }
 
@@ -8297,7 +8301,7 @@ static void ggml_compute_forward_add_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -8322,23 +8326,23 @@ static void ggml_compute_forward_add_f32(
 
     if (nb10 == sizeof(float)) {
         for (int ir = ir0; ir < ir1; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
 #ifdef GGML_USE_ACCELERATE
-            vDSP_vadd(
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                    ne0);
+            vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
 #else
-            ggml_vec_add_f32(ne0,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+            ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
 #endif
                 // }
             // }
@@ -8346,15 +8350,20 @@ static void ggml_compute_forward_add_f32(
     } else {
         // src1 is not contiguous
         for (int ir = ir0; ir < ir1; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 
-            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
             for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
             }