]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : use CUDA memory pool with async memory allocation/deallocation when available...
authorOleksii Maryshchenko <redacted>
Thu, 2 Nov 2023 17:10:39 +0000 (18:10 +0100)
committerGitHub <redacted>
Thu, 2 Nov 2023 17:10:39 +0000 (19:10 +0200)
* Using cuda memory pools for async alloc/dealloc.

* If cuda device doesnt support memory pool than use old implementation.

* Removed redundant cublasSetStream

---------

Co-authored-by: Oleksii Maryshchenko <redacted>
ggml-cuda.cu

index e4629512611b6c8effc34d17ad3f4bf77c367a08..58b58f3315428618d7d32c7389735162adf73c8e 100644 (file)
@@ -181,11 +181,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do {                                                                                \
         cudaError_t err_ = (err);                                                       \
         if (err_ != cudaSuccess) {                                                      \
-            int id;                                                                     \
-            cudaGetDevice(&id);                                                         \
+            int dev_id;                                                                     \
+            cudaGetDevice(&dev_id);                                                         \
             fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                 cudaGetErrorString(err_));                                              \
-            fprintf(stderr, "current device: %d\n", id);                                \
+            fprintf(stderr, "current device: %d\n", dev_id);                                \
             exit(1);                                                                    \
         }                                                                               \
     } while (0)
@@ -195,11 +195,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do {                                                                                \
         cublasStatus_t err_ = (err);                                                    \
         if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
-            int id;                                                                     \
-            cudaGetDevice(&id);                                                         \
+            int dev_id;                                                                     \
+            cudaGetDevice(&dev_id);                                                         \
             fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n",                         \
                     err_, __FILE__, __LINE__, cublasGetStatusString(err_));             \
-            fprintf(stderr, "current device: %d\n", id);                                \
+            fprintf(stderr, "current device: %d\n", dev_id);                                \
             exit(1);                                                                    \
         }                                                                               \
     } while (0)
@@ -465,6 +465,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
 
 #define MAX_STREAMS 8
 static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
+static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr };
 
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@@ -5772,6 +5773,16 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     return ptr;
 }
 
+static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
+    if (g_cudaMemPools[id] == nullptr) {
+        return ggml_cuda_pool_malloc(size, actual_size);
+    }
+    void *ptr;
+    CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
+    *actual_size = size;
+    return ptr;
+}
+
 static void ggml_cuda_pool_free(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
@@ -5790,6 +5801,13 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
 }
 
 
+static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
+    if (g_cudaMemPools[id] == nullptr) {
+        return ggml_cuda_pool_free(ptr, actual_size);
+    }
+    CUDA_CHECK(cudaFreeAsync(ptr, stream));
+}
+
 void ggml_init_cublas() {
     static bool initialized = false;
 
@@ -5844,6 +5862,13 @@ void ggml_init_cublas() {
             // create cublas handle
             CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
             CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
+
+            // configure memory pool
+            cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
+            if (err == cudaSuccess) {
+                size_t treshold = UINT64_MAX;
+                CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
+            }
         }
 
         // configure logging to stdout
@@ -6437,7 +6462,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = row_diff*ne00;
-            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
+            src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
             to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
@@ -6448,13 +6473,12 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = src1_ncols*ne10;
-            src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
+            src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
-
-        size_t dst_as = 0;
-        half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
+        size_t dst_f16_as = 0;
+        half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
 
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
@@ -6472,14 +6496,15 @@ inline void ggml_cuda_op_mul_mat_cublas(
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
 
-        ggml_cuda_pool_free(dst_f16, dst_as);
+        if (dst_f16_as != 0) {
+            ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
+        }
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free(src0_as_f16, src0_as);
+            ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
         }
-
         if (src1_as != 0) {
-            ggml_cuda_pool_free(src1_as_f16, src1_as);
+            ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
         }
     }
     else {
@@ -6489,7 +6514,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
         if (src0->type != GGML_TYPE_F32) {
             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
             GGML_ASSERT(to_fp32_cuda != nullptr);
-            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
+            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT
             to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
         }
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
@@ -6506,7 +6531,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
                     &beta,  dst_dd_i,   ldc));
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
+            ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
         }
     }
 
@@ -6929,21 +6954,22 @@ static void ggml_cuda_op_mul_mat(
             src0_dd[id] = (char *) src0_extra->data_device[id];
         } else {
             const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
-            src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
+            src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream);
         }
 
         if (src1_on_device && src1_is_contiguous) {
             src1_ddf[id] = (float *) src1_extra->data_device[id];
         } else {
-            src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
+            src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
         }
 
         if (convert_src1_to_q8_1) {
-            src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
+            const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
+            src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
 
             if (src1_on_device && src1_is_contiguous) {
                 quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
-                CUDA_CHECK(cudaGetLastError());
+                // CUDA_CHECK(cudaGetLastError());
             }
         }
 
@@ -6951,7 +6977,7 @@ static void ggml_cuda_op_mul_mat(
             dst_dd[id] = (float *) dst_extra->data_device[id];
         } else {
             const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
-            dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
+            dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id,  stream);
         }
     }
 
@@ -7077,24 +7103,6 @@ static void ggml_cuda_op_mul_mat(
         }
     }
 
-    for (int64_t id = 0; id < g_device_count; ++id) {
-        CUDA_CHECK(ggml_cuda_set_device(id));
-
-        // free buffers again when done
-        if (src0_as[id] > 0) {
-            ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
-        }
-        if (src1_asf[id] > 0) {
-            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
-        }
-        if (src1_asq[id] > 0) {
-            ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
-        }
-        if (dst_as[id] > 0) {
-            ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
-        }
-    }
-
     // main device waits for all other devices to be finished
     if (split && g_device_count > 1) {
         int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
@@ -7112,6 +7120,21 @@ static void ggml_cuda_op_mul_mat(
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaDeviceSynchronize());
     }
+
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        if (src0_as[id] > 0) {
+            ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
+        }
+        if (src1_asf[id] > 0) {
+            ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
+        }
+        if (src1_asq[id] > 0) {
+            ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
+        }
+        if (dst_as[id] > 0) {
+            ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
+        }
+    }
 }
 
 static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7298,11 +7321,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     GGML_ASSERT(to_fp16_cuda != nullptr);
 
     size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
 
     size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
@@ -7349,10 +7372,9 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     } else {
         // use cublasGemmBatchedEx
         const int ne23 = ne12*ne13;
-
-        void ** ptrs_as = nullptr;
+        // allocate device memory for pointers
         size_t ptrs_s = 0;
-        ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
+        void ** ptrs_as = (void **)ggml_cuda_pool_malloc_async(3*ne23*sizeof(void *), &ptrs_s, id, main_stream);
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
@@ -7365,7 +7387,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 dst->nb[2], dst->nb[3],
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
-
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
@@ -7375,16 +7396,21 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 ne23,
                 CUBLAS_COMPUTE_16F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
-        ggml_cuda_pool_free(ptrs_as, ptrs_s);
+        // free device memory for pointers
+        if (ptrs_s != 0) {
+            ggml_cuda_pool_free_async(ptrs_as, ptrs_s, id, main_stream);
+        }
     }
 #endif
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
     to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
-    ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
+    if (src1_as != 0) {
+        ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
+    }
+    if (dst_as != 0) {
+        ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
+    }
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {