]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuBLAS: use host pinned memory and dequantize while copying (#1207)
authorslaren <redacted>
Sat, 29 Apr 2023 00:04:18 +0000 (02:04 +0200)
committerGitHub <redacted>
Sat, 29 Apr 2023 00:04:18 +0000 (02:04 +0200)
* cuBLAS: dequantize simultaneously while copying memory

* cuBLAS: use host pinned memory

* cuBLAS: improve ggml_compute_forward_mul_mat_f16_f32 with pinned memory

* cuBLAS: also pin kv cache

* fix rebase

Makefile
ggml-cuda.cu
ggml-cuda.h
ggml.c
llama.cpp
llama_util.h

index 0715e857bc34663da2bf0e7d40a71fb8041f9ac0..5a1cb3e83e365fdf65fd408b921c62e3446a93d1 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -106,6 +106,7 @@ ifdef LLAMA_OPENBLAS
 endif
 ifdef LLAMA_CUBLAS
        CFLAGS    += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
+       CXXFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
        LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
        OBJS      += ggml-cuda.o
        NVCC      = nvcc
@@ -164,10 +165,10 @@ $(info )
 # Build library
 #
 
-ggml.o: ggml.c ggml.h
+ggml.o: ggml.c ggml.h ggml-cuda.h
        $(CC)  $(CFLAGS)   -c $< -o $@
 
-llama.o: llama.cpp ggml.h llama.h llama_util.h
+llama.o: llama.cpp ggml.h ggml-cuda.h llama.h llama_util.h
        $(CXX) $(CXXFLAGS) -c $< -o $@
 
 common.o: examples/common.cpp examples/common.h
index eb244f409aafdc8f9916cb838fadda89b0e10e50..5a2701cfeef68696730b7c3c067fb41a60ad221d 100644 (file)
@@ -227,6 +227,25 @@ void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t st
     dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
 }
 
+dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q4_2:
+            return dequantize_row_q4_2_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_row_q5_0_cuda;
+        case GGML_TYPE_Q5_1:
+            return dequantize_row_q5_1_cuda;
+        case GGML_TYPE_Q8_0:
+            return dequantize_row_q8_0_cuda;
+        default:
+            return nullptr;
+    }
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 16
 
@@ -286,18 +305,22 @@ void ggml_cuda_pool_free(void * ptr, size_t size) {
     CUDA_CHECK(cudaFree(ptr));
 }
 
-cublasHandle_t g_cublasH = NULL;
-cudaStream_t g_cudaStream = NULL;
+cublasHandle_t g_cublasH = nullptr;
+cudaStream_t g_cudaStream = nullptr;
+cudaStream_t g_cudaStream2 = nullptr;
+cudaEvent_t g_cudaEvent = nullptr;
 
-void ggml_init_cublas(void) {
-    if (g_cublasH == NULL) {
+void ggml_init_cublas() {
+    if (g_cublasH == nullptr) {
         // create cublas handle, bind a stream
         CUBLAS_CHECK(cublasCreate(&g_cublasH));
-
         CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
-
         CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
 
+        // create additional stream and event for synchronization
+        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
+        CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming));
+
         // configure logging to stdout
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
     }
@@ -330,3 +353,13 @@ cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src,
         return cudaSuccess;
     }
 }
+
+void * ggml_cuda_host_malloc(size_t size) {
+    void * ptr;
+    CUDA_CHECK(cudaMallocHost((void **) &ptr, size));
+    return ptr;
+}
+
+void ggml_cuda_host_free(void * ptr) {
+    CUDA_CHECK(cudaFreeHost(ptr));
+}
index 1fd67ebeb71cc8ee01ae33c527ca2c8f9c0db768..36782d9e796b7ed873a036cc3cfb542ead3e1116 100644 (file)
@@ -26,9 +26,14 @@ extern "C" {
     } while (0)
 
 extern cublasHandle_t g_cublasH;
-extern cudaStream_t   g_cudaStream;
+extern cudaStream_t g_cudaStream;
+extern cudaStream_t g_cudaStream2;
+extern cudaEvent_t g_cudaEvent;
 
 void   ggml_init_cublas(void);
+void * ggml_cuda_host_malloc(size_t size);
+void   ggml_cuda_host_free(void * ptr);
+
 void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
 void   ggml_cuda_pool_free(void * ptr, size_t size);
 
@@ -41,6 +46,9 @@ void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t st
 
 cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream);
 
+typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
+dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(enum ggml_type type);
+
 #ifdef  __cplusplus
 }
 #endif
diff --git a/ggml.c b/ggml.c
index 4ec637ee1e0821eab7e5e3df2d951402557092dd..64ecd0867e4fb0571141140f558c47656388e8a1 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -8033,7 +8033,7 @@ static void ggml_compute_forward_mul_mat_f32(
 #if defined(GGML_USE_CUBLAS)
         const float alpha = 1.0f;
         const float beta = 0.0f;
-        const int x_ne = ne01 * ne10;
+        const int x_ne = ne01 * ne00;
         const int y_ne = ne11 * ne10;
         const int d_ne = ne11 * ne01;
 
@@ -8235,25 +8235,27 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
-        ggml_fp16_t * const wdata = params->wdata;
-
         const float alpha = 1.0f;
         const float beta = 0.0f;
-        const int x_ne = ne01 * ne10;
+        const int x_ne = ne01 * ne00;
         const int y_ne = ne11 * ne10;
         const int d_ne = ne11 * ne01;
 
         size_t x_size, y_size, d_size;
-        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+        ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float       * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #else
         float * const wdata = params->wdata;
 #endif
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
 #if defined(GGML_USE_CUBLAS)
+                // copy src0 while converting src1
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
+
                 // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
+                ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + (ne11 * ne10) * (i03 * ne02 + i02);
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne11; ++i01) {
@@ -8275,11 +8277,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 
 #if defined(GGML_USE_CUBLAS)
                 const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
-
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
                 // copy data to device
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
                 // compute
@@ -8498,39 +8498,19 @@ static void ggml_compute_forward_mul_mat_q_f32(
 #if defined(GGML_USE_CUBLAS)
         const float alpha = 1.0f;
         const float beta = 0.0f;
-        const int x_ne = ne01 * ne10;
+        const int x_ne = ne01 * ne00;
         const int y_ne = ne11 * ne10;
         const int d_ne = ne11 * ne01;
 
         size_t x_size, y_size, d_size, q_size;
-        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
-        float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
+        float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+        void  * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
 
-        void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream)  = NULL;
-        if (type == GGML_TYPE_Q4_0) {
-            dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
-        }
-        else if (type == GGML_TYPE_Q4_1) {
-            dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
-        }
-        else if (type == GGML_TYPE_Q4_2) {
-            dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
-        }
-        else if (type == GGML_TYPE_Q5_0) {
-            dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
-        }
-        else if (type == GGML_TYPE_Q5_1) {
-            dequantize_row_q_cuda = dequantize_row_q5_1_cuda;
-        }
-        else if (type == GGML_TYPE_Q8_0) {
-            dequantize_row_q_cuda = dequantize_row_q8_0_cuda;
-        }
-        else {
-            GGML_ASSERT(false);
-        }
-#elif !defined(GGML_USE_CLBLAST)
+        const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type);
+        GGML_ASSERT(dequantize_row_q_cuda != NULL);
+#else
         float * const wdata = params->wdata;
         dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 #endif
@@ -8543,10 +8523,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
 #if defined(GGML_USE_CUBLAS)
                 // copy and dequantize on device
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream));
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream2));
 
-                dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
+                dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);
                 CUDA_CHECK(cudaGetLastError());
+                CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2));
 #elif defined(GGML_USE_CLBLAST)
                 const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
 #else
@@ -8560,11 +8541,13 @@ static void ggml_compute_forward_mul_mat_q_f32(
                 const float * x = wdata;
 #endif
 
-
 #if defined(GGML_USE_CUBLAS)
                 // copy data to device
                 CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
 
+                // wait for dequantization
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0));
+
                 // compute
                 CUBLAS_CHECK(
                     cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
@@ -11588,7 +11571,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1; // TODO: this actually is doing nothing
                                                    //       the threads are still spinning
-                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*MAX(ggml_nelements(node->src1), ggml_nelements(node->src0));
                                 //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
                                 //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
                                 //printf("cur = %zu\n", cur);
@@ -11600,6 +11583,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 #endif
                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                             cur = 0;
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                                node->n_tasks = 1;
+                            }
+#endif
                         } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
index 45f0d44acc5481ad2075bf70096d6f114aad9078..4699e5cf1de7c4299ce1ddd64c317727b6de5166 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -136,7 +136,7 @@ struct llama_kv_cache {
 
     struct ggml_context * ctx = NULL;
 
-    llama_buffer buf;
+    llama_ctx_buffer buf;
 
     int n; // number of tokens currently in the cache
 
@@ -167,7 +167,7 @@ struct llama_model {
     struct llama_kv_cache kv_self;
 
     // the model memory buffer
-    llama_buffer buf;
+    llama_ctx_buffer buf;
 
     // model memory mapped file
     std::unique_ptr<llama_mmap> mapping;
@@ -228,8 +228,8 @@ struct llama_context {
 
     // memory buffers used to evaluate the model
     // TODO: move in llama_state
-    llama_buffer buf_compute;
-    llama_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
+    llama_ctx_buffer buf_compute;
+    llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
 
     int    buf_last = 0;
     size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
index acb207e653c1063cfca1bac30b559171566916fe..6e66d12a8041ce9c6c258100f74e6fa549fdaa20 100755 (executable)
@@ -405,4 +405,30 @@ struct llama_buffer {
         delete[] addr;
     }
 };
+
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+struct llama_ctx_buffer {
+    uint8_t * addr = NULL;
+    size_t size = 0;
+
+    void resize(size_t size) {
+        if (addr) {
+            ggml_cuda_host_free(addr);
+        }
+        addr = (uint8_t *) ggml_cuda_host_malloc(size);
+        this->size = size;
+    }
+
+    ~llama_ctx_buffer() {
+        if (addr) {
+            ggml_cuda_host_free(addr);
+        }
+    }
+};
+#else
+typedef llama_buffer llama_ctx_buffer;
+#endif
+
+
 #endif