git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : use a single kernel for CUDA mul op (#373)
author Jiahao Li <redacted>
Tue, 11 Jul 2023 18:12:57 +0000 (02:12 +0800)
committer GitHub <redacted>
Tue, 11 Jul 2023 18:12:57 +0000 (21:12 +0300)
src/ggml-cuda.cu

index 17c8bb76137e68d9b4611126f61b585992bff775..2fb30c6e6087ce560c56cc3b262eb36bb7ff2df2 100644 (file)
@@ -2305,20 +2305,11 @@ inline void ggml_cuda_op_mul(
     GGML_ASSERT(dst_ddf_i != nullptr);
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
 
     const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-
-    for (int64_t i01 = i01_low; i01 < i01_high; i01++) {
-        const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
 
-        float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
-        float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
-        float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
-
-        // compute
-        mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
-    }
+    mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);
 
     (void) dst;
     (void) src0_ddq_i;