]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Only one CUDA stream per device for async compute (#1898)
authorJohannes Gäßler <redacted>
Sat, 17 Jun 2023 17:15:02 +0000 (19:15 +0200)
committerGitHub <redacted>
Sat, 17 Jun 2023 17:15:02 +0000 (19:15 +0200)
README.md
examples/common.cpp
ggml-cuda.cu

index b9759b00b51f7230e228dd38cd7c23bc4941e748..7defb7584db11fd1b799dd61a9f327fd227cde96 100644 (file)
--- a/README.md
+++ b/README.md
@@ -336,7 +336,6 @@ Building the program with BLAS support may lead to some performance improvements
     cmake .. -DLLAMA_CUBLAS=ON
     cmake --build . --config Release
     ```
-  Note: Because llama.cpp uses multiple CUDA streams for matrix multiplication results [are not guaranteed to be reproducible](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility). If you need reproducibility, set `GGML_CUDA_MAX_STREAMS` in the file `ggml-cuda.cu` to 1.
 
   The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used.
 
index 055383beff961dc1677ab99ad5231f4dc0c57a03..fed24e027d8a814146defdb07f02837e172c3664 100644 (file)
@@ -106,9 +106,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         }
 
         if (arg == "-s" || arg == "--seed") {
-#if defined(GGML_USE_CUBLAS)
-            fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n");
-#endif
             if (++i >= argc) {
                 invalid_param = true;
                 break;
index fed2a7ce1095c3cd4fd489088aa75255d9c49b41..16488b9f9067f9460b65e84d440a1fb21854e087 100644 (file)
@@ -1467,19 +1467,13 @@ static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
 static size_t g_scratch_offset = 0;
 
-#define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
-#define GGML_CUDA_MAX_EVENTS 64
-
 static int g_device_count = -1;
 static int g_main_device = 0;
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
-static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
-
-static cudaStream_t g_cudaStreams_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
-static cudaEvent_t g_cudaEvents_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_EVENTS] = { nullptr };
+static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
 
 void ggml_init_cublas() {
     static bool initialized = false;
@@ -1503,15 +1497,8 @@ void ggml_init_cublas() {
         for (int id = 0; id < g_device_count; ++id) {
             CUDA_CHECK(cudaSetDevice(id));
 
-            // create streams
-            for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
-                CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id][i], cudaStreamNonBlocking));
-                CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_src1[id][i], cudaStreamNonBlocking));
-            }
-            // create events
-            for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
-                CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_memcpy_src1[id][i], cudaEventDisableTiming));
-            }
+            // create main stream
+            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking));
 
             // create cublas handle
             CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
@@ -1978,6 +1965,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
     size_t  dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
 
+    // if multiple GPUs are used they need to wait for the main GPU to finish
+    if (split && g_device_count > 1) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
+
     for (int id = 0; id < g_device_count; ++id) {
         if (!split && id != g_main_device) {
             continue;
@@ -2076,9 +2069,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 }
                 const int64_t i11 = i13*ne12 + i12;
 
-                cudaStream_t cudaStream_main        =        g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
-                cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS];
-                cudaEvent_t  cudaEvent_memcpy_src1  =  g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
+                cudaStream_t cudaStream_main = g_cudaStreams_main[id];
 
                 // for split tensors the data begins at i0 == i0_offset_low
                 char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
@@ -2106,14 +2097,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                     if (src1->backend == GGML_BACKEND_CPU) {
                         GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
                         int64_t nrows1 = flatten_rows ? nrows0 : ne11;
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main));
                     } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
                         if (id != g_main_device) {
                             GGML_ASSERT(!flatten_rows);
                             float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
                             src1_ddf_i_source += i11*src1_stride;
                             CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
-                                                    cudaMemcpyDeviceToDevice, cudaStream_memcpy_src1));
+                                                    cudaMemcpyDeviceToDevice, cudaStream_main));
                         }
                     } else if (src1_on_device && !src1_is_contiguous) {
                         GGML_ASSERT(!split);
@@ -2122,7 +2113,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                         GGML_ASSERT(false);
                     }
                 }
-                CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
 
                 if (!src0_on_device || !src0_is_contiguous) {
                     if (src0_is_f32) {
@@ -2138,9 +2128,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                     CUDA_CHECK(cudaGetLastError());
                 }
 
-                // wait with main stream until src1 memcpy is done
-                CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, cudaEvent_memcpy_src1, 0));
-
                 // do the computation
                 op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
 
@@ -2178,8 +2165,13 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 
     // wait until each device is finished, then free their buffers
     for (int id = 0; id < g_device_count; ++id) {
+        if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) {
+            continue;
+        }
+
         CUDA_CHECK(cudaSetDevice(id));
         CUDA_CHECK(cudaDeviceSynchronize());
+
         if (src0_asq[id] > 0) {
             ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
         }
@@ -2245,7 +2237,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     const int64_t ne02 = src0->ne[2];
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
     struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -2257,8 +2249,6 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
     ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
-
-    CUDA_CHECK(cudaDeviceSynchronize());
 }
 
 void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@@ -2276,7 +2266,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int64_t nb02 = src0->nb[2];
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
     struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -2291,8 +2281,6 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int channel_stride_x = nb02 / sizeof(half);
 
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
-
-    CUDA_CHECK(cudaDeviceSynchronize());
 }
 
 void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2348,7 +2336,7 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     const int64_t nb12 = src1->nb[2];
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
     const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
@@ -2366,8 +2354,6 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
         GGML_ASSERT(false);
     }
 
-    CUDA_CHECK(cudaDeviceSynchronize());
-
     (void) dst;
 }