]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Better CUDA synchronization logic (#2057)
authorJohannes Gäßler <redacted>
Sat, 1 Jul 2023 19:49:44 +0000 (21:49 +0200)
committerGitHub <redacted>
Sat, 1 Jul 2023 19:49:44 +0000 (21:49 +0200)
ggml-cuda.cu
ggml-cuda.h

index 4e0d3dbdea4d499862ccc654461faefa3ebac85e..50df20edd7a7b211324e0af1c18e970bc250251d 100644 (file)
@@ -214,6 +214,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
 #endif
 
+struct ggml_tensor_extra_gpu {
+    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+    cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
+};
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -1970,7 +1975,6 @@ inline void ggml_cuda_op_add(
     } else {
         GGML_ASSERT(false);
     }
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2002,7 +2006,6 @@ inline void ggml_cuda_op_mul(
 
         // compute
         mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
-        CUDA_CHECK(cudaGetLastError());
     }
 
     (void) dst;
@@ -2023,7 +2026,6 @@ inline void ggml_cuda_op_silu(
 
     // compute
     silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2046,7 +2048,6 @@ inline void ggml_cuda_op_rms_norm(
 
     // compute
     rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2125,7 +2126,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
             GGML_ASSERT(false);
             break;
     }
-    CUDA_CHECK(cudaGetLastError());
 
 #ifdef GGML_CUDA_DMMV_F16
     if (src1_convert_f16) {
@@ -2202,7 +2202,6 @@ inline void ggml_cuda_op_rope(
 
     // compute
     rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
     (void) src0_ddq_i;
@@ -2226,7 +2225,6 @@ inline void ggml_cuda_op_diag_mask_inf(
 
     // compute
     diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
     (void) src0_ddq_i;
@@ -2248,7 +2246,6 @@ inline void ggml_cuda_op_soft_max(
 
     // compute
     soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2344,10 +2341,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
     size_t  dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
 
-    // if multiple GPUs are used they need to wait for the main GPU to finish
+    // if multiple devices are used they need to wait for the main device
+    // here an event is recorded that signifies that the main device has finished calculating the input data
     if (split && g_device_count > 1) {
         CUDA_CHECK(cudaSetDevice(g_main_device));
-        CUDA_CHECK(cudaDeviceSynchronize());
+        CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device]));
     }
 
     for (int id = 0; id < g_device_count; ++id) {
@@ -2373,6 +2371,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         int64_t row_diff = row_high - row_low;
 
         cudaSetDevice(id);
+        cudaStream_t cudaStream_main = g_cudaStreams_main[id];
+
+        // wait for main GPU data if necessary
+        if (split && id != g_main_device) {
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device]));
+        }
 
         if (src0_on_device && src0_is_contiguous) {
             if (src0_is_f32) {
@@ -2448,8 +2452,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 }
                 const int64_t i11 = i13*ne12 + i12;
 
-                cudaStream_t cudaStream_main = g_cudaStreams_main[id];
-
                 // for split tensors the data begins at i0 == i0_offset_low
                 char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
                 float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
@@ -2509,6 +2511,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 
                 // do the computation
                 op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
+                CUDA_CHECK(cudaGetLastError());
 
                 // copy dst to host or other device if necessary
                 if (!dst_on_device) {
@@ -2538,6 +2541,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                         CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
                     }
                 }
+
+                // signify to main device that other device is done
+                if (split && g_device_count > 1 && id != g_main_device) {
+                    CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main));
+                }
             }
         }
     }
@@ -2549,7 +2557,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         }
 
         CUDA_CHECK(cudaSetDevice(id));
-        CUDA_CHECK(cudaDeviceSynchronize());
 
         if (src0_asq[id] > 0) {
             ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
@@ -2564,6 +2571,21 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
         }
     }
+
+    // main device waits for all other devices to be finished
+    if (split && g_device_count > 1) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        for (int id = 0; id < g_device_count; ++id) {
+            if (id != g_main_device) {
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
+            }
+        }
+    }
+
+    if (dst->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
 }
 
 void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2803,6 +2825,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
 
         extra->data_device[id] = buf;
+
+        if (backend == GGML_BACKEND_GPU_SPLIT) {
+            CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming));
+        }
     }
 
     tensor->extra = extra;
@@ -2816,12 +2842,15 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
     for (int id = 0; id < g_device_count; ++id) {
-        if (extra->data_device[id] == nullptr) {
-            continue;
+        if (extra->data_device[id] != nullptr) {
+            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(cudaFree(extra->data_device[id]));
         }
 
-        CUDA_CHECK(cudaSetDevice(id));
-        CUDA_CHECK(cudaFree(extra->data_device[id]));
+        if (extra->events[id] != nullptr) {
+            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(cudaEventDestroy(extra->events[id]));
+        }
     }
 
     delete extra;
index 7a65a3558a074d5d9d4052e2e2fdd1075d40a168..3c1e8deb6a6ddbfbb43e22bd01cc64e560fe313d 100644 (file)
@@ -8,10 +8,6 @@ extern "C" {
 
 #define GGML_CUDA_MAX_DEVICES       16
 
-struct ggml_tensor_extra_gpu {
-    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
-};
-
 void   ggml_init_cublas(void);
 void   ggml_cuda_set_tensor_split(const float * tensor_split);