]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : support broadcast add & mul (#2192)
authorJiahao Li <redacted>
Fri, 14 Jul 2023 18:38:24 +0000 (02:38 +0800)
committerGitHub <redacted>
Fri, 14 Jul 2023 18:38:24 +0000 (21:38 +0300)
Co-authored-by: Georgi Gerganov <redacted>
ggml-cuda.cu

index 4c9e21429e10f4ead6bdbb6261625b8cc9d84be0..73cfc55e54d05da3a0a8ce6a856b4b1fe12266ba 100644 (file)
@@ -252,13 +252,13 @@ struct ggml_tensor_extra_gpu {
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
 };
 
-static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
+static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i >= k) {
+    if (i >= kx) {
         return;
     }
-    dst[i] = x[i] + y[i];
+    dst[i] = x[i] + y[i%ky];
 }
 
 static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
@@ -1996,9 +1996,9 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
     dst[i] = scale * x[i];
 }
 
-static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
-    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 }
 
 static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
@@ -2610,17 +2610,15 @@ inline void ggml_cuda_op_add(
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i  != nullptr);
 
-    // TODO: support broadcasting
-    GGML_ASSERT(ggml_nelements(src0) == ggml_nelements(src1));
-
     const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
-    // const int64_t ne10 = src1->ne[0];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
 
     // compute
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
         add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
     } else {
@@ -2644,19 +2642,12 @@ inline void ggml_cuda_op_mul(
     GGML_ASSERT(dst_ddf_i  != nullptr);
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
 
-    for (int64_t i01 = i01_low; i01 < i01_high; i01++) {
-        const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
-
-        float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
-        float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
-        float * dst_ddf_i01  = dst_ddf_i + i01*ne00;
-
-        // compute
-        mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
-    }
+    mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
 
     (void) dst;
     (void) src0_ddq_i;