]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add NVIDIA cuBLAS support (#1044)
authorslaren <redacted>
Wed, 19 Apr 2023 09:22:45 +0000 (11:22 +0200)
committerGitHub <redacted>
Wed, 19 Apr 2023 09:22:45 +0000 (11:22 +0200)
CMakeLists.txt
Makefile
ggml.c
ggml.h
llama.cpp

index ed9a3aa38229014a8c932249a7a55b7c05bd060b..8eadea4fd4862709cb823605b605366dd59337a7 100644 (file)
@@ -66,6 +66,7 @@ endif()
 # 3rd party libs
 option(LLAMA_ACCELERATE             "llama: enable Accelerate framework"                    ON)
 option(LLAMA_OPENBLAS               "llama: use OpenBLAS"                                   OFF)
+option(LLAMA_CUBLAS                 "llama: use cuBLAS"                                     OFF)
 
 option(LLAMA_BUILD_TESTS            "llama: build tests"    ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES         "llama: build examples" ${LLAMA_STANDALONE})
@@ -142,6 +143,26 @@ if (LLAMA_OPENBLAS)
     endif()
 endif()
 
+if (LLAMA_CUBLAS)
+    cmake_minimum_required(VERSION 3.17)
+
+    find_package(CUDAToolkit)
+    if (CUDAToolkit_FOUND)
+        message(STATUS "cuBLAS found")
+
+        add_compile_definitions(GGML_USE_CUBLAS)
+
+        if (LLAMA_STATIC)
+            set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+        else()
+            set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
+        endif()
+
+    else()
+        message(WARNING "cuBLAS not found")
+    endif()
+endif()
+
 if (LLAMA_ALL_WARNINGS)
     if (NOT MSVC)
         set(c_flags
index 071d9561ac57b220f97b4749fbf08ecd678ab215..deb0d00090f5aff5db9f24d81dfc6bb14161ddf3 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -97,6 +97,10 @@ ifdef LLAMA_OPENBLAS
        CFLAGS  += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
        LDFLAGS += -lopenblas
 endif
+ifdef LLAMA_CUBLAS
+       CFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
+       LDFLAGS += -lcublas_static -lculibos -lcudart_static -lcublasLt_static -lpthread -ldl -L/usr/local/cuda/lib64
+endif
 ifdef LLAMA_GPROF
        CFLAGS   += -pg
        CXXFLAGS += -pg
diff --git a/ggml.c b/ggml.c
index f4b8fc2829463d2327a340c58f5190b1143bbd6d..13c1548fee895a37cb416fc1e9923b7e32b4f14a 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -142,10 +142,46 @@ inline static void* ggml_aligned_malloc(size_t size) {
         } \
     } while (0)
 
-#ifdef GGML_USE_ACCELERATE
+#if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
-#elif GGML_USE_OPENBLAS
+#elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
+#elif defined(GGML_USE_CUBLAS)
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+#define CUDA_CHECK(err)                                                                            \
+    do {                                                                                           \
+        cudaError_t err_ = (err);                                                                  \
+        if (err_ != cudaSuccess) {                                                                 \
+            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,                       \
+                cudaGetErrorString(err_));                                                         \
+            exit(1);                                                                               \
+        }                                                                                          \
+    } while (0)
+
+#define CUBLAS_CHECK(err)                                                                          \
+    do {                                                                                           \
+        cublasStatus_t err_ = (err);                                                               \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                                       \
+            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);                        \
+            exit(1);                                                                               \
+        }                                                                                          \
+    } while (0)
+
+static cublasHandle_t cublasH = NULL;
+static cudaStream_t cudaStream = NULL;
+static void init_cublas(void) {
+    if (cublasH == NULL) {
+        // create cublas handle, bind a stream
+        CUBLAS_CHECK(cublasCreate(&cublasH));
+
+        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
+
+        // configure logging to stdout
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
+    }
+}
 #endif
 
 #undef MIN
@@ -3836,6 +3872,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
         }
 
+        // initialize cuBLAS
+        #if defined(GGML_USE_CUBLAS)
+        init_cublas();
+        #endif
+
         is_first_call = false;
     }
 
@@ -7567,7 +7608,7 @@ static void ggml_compute_forward_rms_norm(
 
 // ggml_compute_forward_mul_mat
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
 static bool ggml_compute_forward_mul_mat_use_blas(
@@ -7607,7 +7648,7 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     const int64_t ne10 = src1->ne[0];
 #endif
     const int64_t ne11 = src1->ne[1];
@@ -7664,7 +7705,7 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -7678,6 +7719,21 @@ static void ggml_compute_forward_mul_mat_f32(
             return;
         }
 
+#if defined(GGML_USE_CUBLAS)
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#endif
+
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
@@ -7685,15 +7741,37 @@ static void ggml_compute_forward_mul_mat_f32(
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
-
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
@@ -7823,7 +7901,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
@@ -7839,10 +7917,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             return;
         }
 
-        float * const wdata = params->wdata;
+#if defined(GGML_USE_CUBLAS)
+        ggml_fp16_t * const wdata = params->wdata;
 
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#else
+        float * const wdata = params->wdata;
+#endif
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+#if defined(GGML_USE_CUBLAS)
+                // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
+                {
+                    size_t id = 0;
+                    for (int64_t i01 = 0; i01 < ne11; ++i01) {
+                        for (int64_t i00 = 0; i00 < ne10; ++i00) {
+                            wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
+                        }
+                    }
+                }
+#else
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7851,7 +7956,32 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         }
                     }
                 }
+#endif
 
+#if defined(GGML_USE_CUBLAS)
+                const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
+                const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, CUDA_R_16F, ne00,
+                                    d_Y, CUDA_R_16F, ne10,
+                            &beta,  d_D, CUDA_R_32F, ne01,
+                            CUBLAS_COMPUTE_32F,
+                            CUBLAS_GEMM_DEFAULT));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
@@ -7863,9 +7993,15 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
         return;
@@ -8017,7 +8153,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -8034,6 +8170,21 @@ static void ggml_compute_forward_mul_mat_q_f32(
         float * const wdata = params->wdata;
         dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 
+#if defined(GGML_USE_CUBLAS)
+        float *d_X = NULL;
+        float *d_Y = NULL;
+        float *d_D = NULL;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+        const int x_ne = ne01 * ne10;
+        const int y_ne = ne11 * ne10;
+        const int d_ne = ne11 * ne01;
+
+        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+#endif
+
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 {
@@ -8049,15 +8200,38 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
+#if defined(GGML_USE_CUBLAS)
+                // copy data to device
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+                // compute
+                CUBLAS_CHECK(
+                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, d_X, ne00,
+                                    d_Y, ne10,
+                            &beta,  d_D, ne01));
+
+                // copy data to host
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+#else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
                                  x, ne00,
                         0.0f,    d, ne01);
+#endif
             }
         }
 
+#if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaFree(d_X));
+        CUDA_CHECK(cudaFree(d_Y));
+        CUDA_CHECK(cudaFree(d_D));
+#endif
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
@@ -10874,7 +11048,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         size_t cur = 0;
 
                         if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1; // TODO: this actually is doing nothing
                                                    //       the threads are still spinning
@@ -10891,7 +11065,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                             cur = 0;
                         } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1;
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
@@ -12231,7 +12405,15 @@ int ggml_cpu_has_wasm_simd(void) {
 }
 
 int ggml_cpu_has_blas(void) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_cublas(void) {
+#if defined(GGML_USE_CUBLAS)
     return 1;
 #else
     return 0;
diff --git a/ggml.h b/ggml.h
index 603be84531b145b194fb3c270eab0f6c236aac55..570147fc246587603042beed51a0cdb61bae17f1 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -825,6 +825,7 @@ int ggml_cpu_has_f16c(void);
 int ggml_cpu_has_fp16_va(void);
 int ggml_cpu_has_wasm_simd(void);
 int ggml_cpu_has_blas(void);
+int ggml_cpu_has_cublas(void);
 int ggml_cpu_has_sse3(void);
 int ggml_cpu_has_vsx(void);
 
index f14324fcd1eababf2d8b886bc237adf4fd3eb704..3ff5dc1e14800952e24b19647c87883c6803900f 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1069,7 +1069,7 @@ static bool llama_eval_internal(
     // for big prompts, if BLAS is enabled, it is better to use only one thread
     // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
     ggml_cgraph gf = {};
-    gf.n_threads = N >= 32 && ggml_cpu_has_blas() ? 1 : n_threads;
+    gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_cublas() ? 1 : n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, tokens, N*ggml_element_size(embd));