]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : GPU-accelerated token generation (#1412)
authorJohannes Gäßler <redacted>
Sat, 13 May 2023 13:38:36 +0000 (15:38 +0200)
committerGitHub <redacted>
Sat, 13 May 2023 13:38:36 +0000 (16:38 +0300)
* CUDA kernel for q4_0 dequant. + mat. vec. mult.

* Added q4_1 via template

* Added missing __syncthreads();

* --gpu_layers -> --gpu-layers

* Shorter dequantize_mul_mat_vec line

* q5_0 dequantize_mul_mat kernel

* More readable dequantize_mul_mat_vec logic

* dequantize_mul_mat_vec kernels for q5_1, q8_0, f16

* llama : offload "output" tensor to GPU too + coding style fixes

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/common.cpp
examples/common.h
ggml-cuda.cu
ggml-cuda.h
ggml.c
ggml.h
llama.cpp
llama.h

index 80e35d2e9cec8f90846f8bc09cb66b0b7df9575f..86c1eef41b475200fbdb34da4e3fafb6d5e04580 100644 (file)
@@ -277,6 +277,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.use_color = true;
         } else if (arg == "--mlock") {
             params.use_mlock = true;
+        } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_gpu_layers = std::stoi(argv[i]);
         } else if (arg == "--no-mmap") {
             params.use_mmap = false;
         } else if (arg == "--mtest") {
@@ -421,6 +427,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     if (llama_mmap_supported()) {
         fprintf(stderr, "  --no-mmap             do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
     }
+    fprintf(stderr, "  -ngl N, --n-gpu-layers N\n");
+    fprintf(stderr, "                        number of layers to store in VRAM\n");
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
@@ -463,14 +471,15 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
 struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
     auto lparams = llama_context_default_params();
 
-    lparams.n_ctx      = params.n_ctx;
-    lparams.n_parts    = params.n_parts;
-    lparams.seed       = params.seed;
-    lparams.f16_kv     = params.memory_f16;
-    lparams.use_mmap   = params.use_mmap;
-    lparams.use_mlock  = params.use_mlock;
-    lparams.logits_all = params.perplexity;
-    lparams.embedding  = params.embedding;
+    lparams.n_ctx        = params.n_ctx;
+    lparams.n_parts      = params.n_parts;
+    lparams.n_gpu_layers = params.n_gpu_layers;
+    lparams.seed         = params.seed;
+    lparams.f16_kv       = params.memory_f16;
+    lparams.use_mmap     = params.use_mmap;
+    lparams.use_mlock    = params.use_mlock;
+    lparams.logits_all   = params.perplexity;
+    lparams.embedding    = params.embedding;
 
     llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
 
index 499671b2e8d6dccf067c345dde7d4745985649a2..717838f06e0641f81a1e367a5a7f5188987efabd 100644 (file)
 int32_t get_num_physical_cores();
 
 struct gpt_params {
-    int32_t seed          = -1;   // RNG seed
+    int32_t seed          = -1;  // RNG seed
     int32_t n_threads     = get_num_physical_cores();
     int32_t n_predict     = -1;  // new tokens to predict
-    int32_t n_parts       = -1;   // amount of model parts (-1 = determine from model dimensions)
-    int32_t n_ctx         = 512;  // context size
-    int32_t n_batch       = 512;  // batch size for prompt processing (must be >=32 to use BLAS)
-    int32_t n_keep        = 0;    // number of tokens to keep from initial prompt
+    int32_t n_parts       = -1;  // amount of model parts (-1 = determine from model dimensions)
+    int32_t n_ctx         = 512; // context size
+    int32_t n_batch       = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_keep        = 0;   // number of tokens to keep from initial prompt
+    int32_t n_gpu_layers  = 0;   // number of layers to store in VRAM
 
     // sampling parameters
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
index 8a3beb0e54b880374536aa1a0670bdf74e0ec8d0..b6a7754d534e681df8c3f089b31ce50034eb9ef8 100644 (file)
@@ -32,9 +32,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
         }                                                                               \
     } while (0)
 
+typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
 typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
+typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
+
+// QK = number of values after dequantization
+// QR = QK / number of values before dequantization
 
 #define QK4_0 32
+#define QR4_0 2
 typedef struct {
     float   d;              // delta
     uint8_t qs[QK4_0 / 2];  // nibbles / quants
@@ -42,6 +48,7 @@ typedef struct {
 static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
 
 #define QK4_1 32
+#define QR4_1 2
 typedef struct {
     float   d;              // delta
     float   m;              // min
@@ -50,6 +57,7 @@ typedef struct {
 static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
 
 #define QK5_0 32
+#define QR5_0 2
 typedef struct {
     half d;                 // delta
     uint8_t qh[4];          // 5-th bit of quants
@@ -58,6 +66,7 @@ typedef struct {
 static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
 
 #define QK5_1 32
+#define QR5_1 2
 typedef struct {
     half d;                 // delta
     half m;                 // min
@@ -67,12 +76,100 @@ typedef struct {
 static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
 
 #define QK8_0 32
+#define QR8_0 1
 typedef struct {
     float   d;              // delta
     int8_t  qs[QK8_0];      // quants
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 
+#define CUDA_DMMV_BLOCK_SIZE 32
+
+static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const float d = x[ib].d;
+
+    const uint8_t vui = x[ib].qs[iqs];
+
+    const int8_t vi0 = vui & 0xF;
+    const int8_t vi1 = vui >> 4;
+
+    v0 = (vi0 - 8)*d;
+    v1 = (vi1 - 8)*d;
+}
+
+static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const float d = x[ib].d;
+    const float m = x[ib].m;
+
+    const uint8_t vui = x[ib].qs[iqs];
+
+    const int8_t vi0 = vui & 0xF;
+    const int8_t vi1 = vui >> 4;
+
+    v0 = vi0*d + m;
+    v1 = vi1*d + m;
+}
+
+static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+    const block_q5_0 * x = (const block_q5_0 *) vx;
+
+    const float d = x[ib].d;
+
+    uint32_t qh;
+    memcpy(&qh, x[ib].qh, sizeof(qh));
+
+    const uint8_t xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const uint8_t xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+
+    const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
+    const int32_t x1 = ((x[ib].qs[iqs] >>  4) | xh_1) - 16;
+
+    v0 = x0*d;
+    v1 = x1*d;
+}
+
+static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+    const block_q5_1 * x = (const block_q5_1 *) vx;
+
+    const float d = x[ib].d;
+    const float m = x[ib].m;
+
+    uint32_t qh;
+    memcpy(&qh, x[ib].qh, sizeof(qh));
+
+    const uint8_t xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const uint8_t xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+
+    const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    const int32_t x1 = ((x[ib].qs[iqs] >>  4) | xh_1);
+
+    v0 = x0*d + m;
+    v1 = x1*d + m;
+}
+
+static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+    const block_q8_0 * x = (const block_q8_0 *) vx;
+
+    const float d = x[ib].d;
+
+    const int8_t vi0 = x[ib].qs[iqs + 0];
+    const int8_t vi1 = x[ib].qs[iqs + 1];
+
+    v0 = vi0*d;
+    v1 = vi1*d;
+}
+
+static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+    const half * x = (const half *) vx;
+
+    v0 = __half2float(x[ib + 0]);
+    v1 = __half2float(x[ib + 1]);
+}
+
 static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
     static const int qk = QK4_0;
 
@@ -173,6 +270,44 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
     }
 }
 
+template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int tid = threadIdx.x;
+
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    __shared__ float tmp[block_size]; // separate sum for each thread
+    tmp[tid] = 0;
+
+    for (int i = 0; i < ncols/block_size; i += 2) {
+        const int col = i*block_size + 2*tid;
+        const int ib = (row*ncols + col)/qk; // block index
+        const int iqs = (col%qk)/qr; // quant index
+        const int iybs = col - col%qk; // y block start index
+
+        // dequantize
+        float v0, v1;
+        dequantize_kernel(vx, ib, iqs, v0, v1);
+
+        // matrix multiplication
+        tmp[tid] += v0 * y[iybs + iqs + 0];
+        tmp[tid] += v1 * y[iybs + iqs + y_offset];
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+    for (int s=block_size/2; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        __syncthreads();
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
+}
+
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK4_0;
     dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
@@ -198,6 +333,36 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
     dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
 }
 
+static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
+        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
+        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
+        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
+        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
+        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
 // TODO: optimize
 static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
     const half * x = (const half *) vx;
@@ -211,6 +376,12 @@ static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStre
     convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
 }
 
+static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
+        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
@@ -230,8 +401,27 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     }
 }
 
+static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_mul_mat_vec_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_mul_mat_vec_q4_1_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_mul_mat_vec_q5_0_cuda;
+        case GGML_TYPE_Q5_1:
+            return dequantize_mul_mat_vec_q5_1_cuda;
+        case GGML_TYPE_Q8_0:
+            return dequantize_mul_mat_vec_q8_0_cuda;
+        case GGML_TYPE_F16:
+            return dequantize_mul_mat_vec_q8_0_cuda;
+        default:
+            return nullptr;
+    }
+}
+
 // buffer pool for cuda
-#define MAX_CUDA_BUFFERS 16
+#define MAX_CUDA_BUFFERS 256
 
 struct scoped_spin_lock {
     std::atomic_flag& lock;
@@ -528,6 +718,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     const int nb2  = dst->nb[2];
     const int nb3  = dst->nb[3];
     const ggml_type type = src0->type;
+    const bool mul_mat_vec = ne11 == 1;
 
     const float alpha = 1.0f;
     const float beta = 0.0f;
@@ -538,12 +729,16 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
 
     size_t x_size, y_size, d_size, q_size;
-    float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+    float * d_X = nullptr;
+    if (!mul_mat_vec) {
+        d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+    }
     float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
     float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
     char  * d_Q = (char  *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
+    dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
     GGML_ASSERT(to_fp32_cuda != nullptr);
 
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -553,31 +748,54 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
             cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
             cudaEvent_t  cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
 
-            float * c_X = d_X + i * x_ne;
             float * c_Y = d_Y + i * y_ne;
             float * c_D = d_D + i * d_ne;
             char  * c_Q = d_Q + i * q_sz;
 
-            // copy src0 and convert to fp32 on device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
-            to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
-            CUDA_CHECK(cudaGetLastError());
-            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+            // copy src0 to device if necessary
+            if (src0->backend == GGML_BACKEND_CPU) {
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
+            } else if (src0->backend == GGML_BACKEND_CUDA) {
+                c_Q = ((char *) src0->data) + i * q_sz;
+            } else {
+                GGML_ASSERT(false);
+            }
+            if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
+                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
 
-            // copy src1 to device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+                // copy src1 to device
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
 
-            // wait for conversion
-            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+                // wait for data
+                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
 
-            // compute
-            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
-            CUBLAS_CHECK(
-                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                        ne01, ne11, ne10,
-                        &alpha, c_X, ne00,
-                                c_Y, ne10,
-                        &beta,  c_D, ne01));
+                // compute
+                dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+
+            } else { // general dequantization kernel + cuBLAS matrix matrix multiplication
+                float * c_X = d_X + i * x_ne;
+
+                // convert src0 to fp32 on device
+                to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
+                CUDA_CHECK(cudaGetLastError());
+                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+
+                // copy src1 to device
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+
+                // wait for conversion
+                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+
+                // compute
+                CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
+                CUBLAS_CHECK(
+                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, c_X, ne00,
+                                    c_Y, ne10,
+                            &beta,  c_D, ne01));
+            }
 
             // copy dst to host
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
@@ -586,7 +804,9 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     }
 
     CUDA_CHECK(cudaDeviceSynchronize());
-    ggml_cuda_pool_free(d_X, x_size);
+    if (!mul_mat_vec) {
+        ggml_cuda_pool_free(d_X, x_size);
+    }
     ggml_cuda_pool_free(d_Y, y_size);
     ggml_cuda_pool_free(d_D, d_size);
     ggml_cuda_pool_free(d_Q, q_size);
@@ -602,8 +822,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
         src1->type == GGML_TYPE_F32 &&
         dst->type == GGML_TYPE_F32 &&
-        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
-
+        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
         return true;
     }
 
@@ -655,3 +874,25 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct
         return 0;
     }
 }
+
+void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
+    const int64_t ne0 = tensor->ne[0];
+    const int64_t ne1 = tensor->ne[1];
+    const int64_t ne2 = tensor->ne[2];
+    const int64_t ne3 = tensor->ne[3];
+
+    const ggml_type type = tensor->type;
+    const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
+
+    size_t q_size;
+    char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+
+    cudaStream_t cudaStream2 = g_cudaStreams2[0];
+
+    // copy tensor to device
+    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    tensor->data = d_Q;
+    tensor->backend = GGML_BACKEND_CUDA;
+}
index f7d6a8bc1842ac2ba9fb350215d5e7ee2e8f18ac..4e2c24283ccf4b29d180d3f3f0625183f128fb88 100644 (file)
@@ -14,6 +14,8 @@ void   ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
 void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
 
+void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
+
 #ifdef  __cplusplus
 }
 #endif
diff --git a/ggml.c b/ggml.c
index 675eb0d2f46e56780d548a5eb8daf404b769d98b..05746383974a0b8a1fa64a018e4f8141d9734093 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -3882,6 +3882,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
 
     *result = (struct ggml_tensor) {
         /*.type         =*/ type,
+        /*.backend      =*/ GGML_BACKEND_CPU,
         /*.n_dims       =*/ n_dims,
         /*.ne           =*/ { 1, 1, 1, 1 },
         /*.nb           =*/ { 0, 0, 0, 0 },
diff --git a/ggml.h b/ggml.h
index 2745fb30be56feeb6a52e3d10302cecb13f8b48e..967ef72d034dd45b223f5c3641fb888efe42c091 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -243,6 +243,11 @@ extern "C" {
         GGML_TYPE_COUNT,
     };
 
+    enum ggml_backend {
+        GGML_BACKEND_CPU = 0,
+        GGML_BACKEND_CUDA = 1,
+    };
+
     // model file types
     enum ggml_ftype {
         GGML_FTYPE_UNKNOWN     = -1,
@@ -333,6 +338,7 @@ extern "C" {
     // n-dimensional tensor
     struct ggml_tensor {
         enum ggml_type type;
+        enum ggml_backend backend;
 
         int     n_dims;
         int64_t ne[GGML_MAX_DIMS]; // number of elements
@@ -363,7 +369,7 @@ extern "C" {
 
         char name[32];
 
-        char padding[8]; // TODO: remove and add padding to name?
+        char padding[9]; // TODO: remove and add padding to name?
     };
 
     // computation graph
index 08c735234c80680eefbba223db27cbf4bbaf4564..73b932a74eea3f19f677cc47a6961e7b5c1ee375 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -9,6 +9,9 @@
 #include "llama.h"
 
 #include "ggml.h"
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif
 
 #include <array>
 #include <ctime>
@@ -810,6 +813,7 @@ struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
         /*.n_ctx                       =*/ 512,
         /*.n_parts                     =*/ -1,
+        /*.gpu_layers                  =*/ 0,
         /*.seed                        =*/ -1,
         /*.f16_kv                      =*/ false,
         /*.logits_all                  =*/ false,
@@ -876,6 +880,7 @@ static void llama_model_load_internal(
         const std::string & fname,
         llama_context & lctx,
         int n_ctx,
+        int n_gpu_layers,
         ggml_type memory_type,
         bool use_mmap,
         bool use_mlock,
@@ -1022,6 +1027,33 @@ static void llama_model_load_internal(
     ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
 
     model.mapping = std::move(ml->mapping);
+#ifdef GGML_USE_CUBLAS
+    {
+        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
+
+        fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu);
+
+        size_t vram_total = 0;
+
+        for (int i = 0; i < n_gpu; ++i) {
+            const auto & layer = model.layers[i];
+
+            ggml_cuda_transform_tensor(layer.wq); vram_total += ggml_nbytes(layer.wq);
+            ggml_cuda_transform_tensor(layer.wk); vram_total += ggml_nbytes(layer.wk);
+            ggml_cuda_transform_tensor(layer.wv); vram_total += ggml_nbytes(layer.wv);
+            ggml_cuda_transform_tensor(layer.wo); vram_total += ggml_nbytes(layer.wo);
+            ggml_cuda_transform_tensor(layer.w1); vram_total += ggml_nbytes(layer.w1);
+            ggml_cuda_transform_tensor(layer.w2); vram_total += ggml_nbytes(layer.w2);
+            ggml_cuda_transform_tensor(layer.w3); vram_total += ggml_nbytes(layer.w3);
+        }
+        if (n_gpu_layers > (int) hparams.n_layer) {
+            fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__);
+            ggml_cuda_transform_tensor(model.output); vram_total += ggml_nbytes(model.output);
+        }
+
+        fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+    }
+#endif
 
     // loading time will be recalculate after the first eval, so
     // we take page faults deferred by mmap() into consideration
@@ -1032,6 +1064,7 @@ static bool llama_model_load(
         const std::string & fname,
         llama_context & lctx,
         int n_ctx,
+        int n_gpu_layers,
         ggml_type memory_type,
         bool use_mmap,
         bool use_mlock,
@@ -1039,7 +1072,7 @@ static bool llama_model_load(
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock,
+        llama_model_load_internal(fname, lctx, n_ctx, n_gpu_layers, memory_type, use_mmap, use_mlock,
                                   vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::string & err) {
@@ -2111,7 +2144,7 @@ struct llama_context * llama_init_from_file(
 
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
-    if (!llama_model_load(path_model, *ctx, params.n_ctx, memory_type,
+    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type,
                           params.use_mmap, params.use_mlock, params.vocab_only,
                           params.progress_callback, params.progress_callback_user_data)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
diff --git a/llama.h b/llama.h
index ca05645b974dedd6f8125cdeebdc521103253176..21cba8cf61061a0c2263054762671d632444cc34 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -54,9 +54,10 @@ extern "C" {
     typedef void (*llama_progress_callback)(float progress, void *ctx);
 
     struct llama_context_params {
-        int n_ctx;   // text context
-        int n_parts; // -1 for default
-        int seed;    // RNG seed, -1 for random
+        int n_ctx;        // text context
+        int n_parts;      // -1 for default
+        int n_gpu_layers; // number of layers to store in VRAM
+        int seed;         // RNG seed, -1 for random
 
         bool f16_kv;     // use fp16 for KV cache
         bool logits_all; // the llama_eval() call computes all logits, not just the last one