]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA GPU acceleration for LoRAs + f16 models (#1970)
authorJohannes Gäßler <redacted>
Wed, 28 Jun 2023 16:35:54 +0000 (18:35 +0200)
committerGitHub <redacted>
Wed, 28 Jun 2023 16:35:54 +0000 (18:35 +0200)
examples/common.cpp
ggml-cuda.cu
ggml-cuda.h
llama.cpp

index 0023027341e5f46ee79f09d918866b59a6bb2c06..5addd10a13fe9662674c8f0384ebe0951f6fa816 100644 (file)
@@ -416,13 +416,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         exit(1);
     }
 
-#ifdef GGML_USE_CUBLAS
-    if (!params.lora_adapter.empty() && params.n_gpu_layers > 0) {
-        fprintf(stderr, "%s: error: the simultaneous use of LoRAs and GPU acceleration is not supported", __func__);
-        exit(1);
-    }
-#endif // GGML_USE_CUBLAS
-
     if (escape_prompt) {
         process_escapes(params.prompt);
     }
index c34e96abfd08c105407164f244d11d6622af4420..be75cb792e04e0a42ea48ca29136600920ec8756 100644 (file)
@@ -223,6 +223,15 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] + y[i];
 }
 
+static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = __hadd(x[i], __float2half(y[i]));
+}
+
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -1459,6 +1468,11 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in
     add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
 }
 
+static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
 static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
     const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
@@ -1941,7 +1955,7 @@ inline void ggml_cuda_op_add(
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
     cudaStream_t & cudaStream_main){
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i != nullptr);
 
@@ -1949,7 +1963,13 @@ inline void ggml_cuda_op_add(
     const int64_t i01_diff = i01_high - i01_low;
 
     // compute
-    add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else {
+        GGML_ASSERT(false);
+    }
     CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
@@ -2547,8 +2567,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 }
 
 void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true);
+    // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op.
+    // Due to flatten_rows == true this does in practice not make a difference however.
+    // Better solution would be nice but right now that would require disproportionate changes.
+    GGML_ASSERT(
+        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
+        src1->type == GGML_TYPE_F32 &&
+        (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16));
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true);
 }
 
 void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2801,7 +2827,7 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
     delete extra;
 }
 
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -2810,11 +2836,11 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
     if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src0->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
-            ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
+            ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
+        ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace);
     }
 
     tensor->backend = GGML_BACKEND_GPU;
@@ -2822,11 +2848,12 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
     memset(extra, 0, sizeof(*extra));
 
     const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
-        tensor->op == GGML_OP_VIEW;
+        tensor->op == GGML_OP_VIEW ||
+        force_inplace;
     const size_t size = ggml_nbytes(tensor);
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
+    if (inplace && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT)) {
         struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
         size_t offset = 0;
@@ -2865,11 +2892,15 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
 }
 
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true);
+    ggml_cuda_assign_buffers_impl(tensor, true, false);
 }
 
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false);
+}
+
+void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, false, true);
 }
 
 void ggml_cuda_set_main_device(int main_device) {
index d32b4484267ab5a3f9028db07330302432c1fd93..7a65a3558a074d5d9d4052e2e2fdd1075d40a168 100644 (file)
@@ -29,6 +29,7 @@ void   ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
 void   ggml_cuda_free_data(struct ggml_tensor * tensor);
 void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
+void   ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
 void   ggml_cuda_set_main_device(int main_device);
 void   ggml_cuda_set_scratch_size(size_t scratch_size);
 void   ggml_cuda_free_scratch(void);
index 5a142aba642811d14e46fc1683ccfe31db0f1b99..5f3761b0ead24f7f81eb6660e9bbd42757202094 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2976,7 +2976,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
                         return false;
                     }
         }
-        ggml_tensor* lora_tensor;
+        ggml_tensor * lora_tensor;
         if (n_dims == 2) {
             lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]);
         }
@@ -2984,6 +2984,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
             fprintf(stderr, "%s: unsupported tensor dimension %d\n", __func__, n_dims);
             return 1;
         }
+        ggml_set_name(lora_tensor, "lora_tensor");
 
         // load tensor data
         size_t offset = fin.tellg();
@@ -2999,6 +3000,21 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
             lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) {
 
             ggml_tensor * dest_t = model_tensors[base_name];
+
+            offload_func_t offload_func = llama_nop;
+            offload_func_t offload_func_force_inplace = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+            if (dest_t->backend == GGML_BACKEND_GPU || dest_t->backend == GGML_BACKEND_GPU_SPLIT) {
+                if (dest_t->type != GGML_TYPE_F16) {
+                    throw std::runtime_error(format(
+                        "%s: error: the simultaneous use of LoRAs and GPU acceleration is only supported for f16 models", __func__));
+                }
+                offload_func = ggml_cuda_assign_buffers;
+                offload_func_force_inplace = ggml_cuda_assign_buffers_force_inplace;
+            }
+#endif // GGML_USE_CUBLAS
+
             ggml_tensor * base_t;
             if (model_loader) {
                 // load from base model
@@ -3026,7 +3042,12 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
             }
 
             ggml_tensor * loraA = lora_tensors[base_name + ".loraA"];
+            GGML_ASSERT(loraA->type == GGML_TYPE_F32);
+            ggml_set_name(loraA, "loraA");
+
             ggml_tensor * loraB = lora_tensors[base_name + ".loraB"];
+            GGML_ASSERT(loraB->type == GGML_TYPE_F32);
+            ggml_set_name(loraB, "loraB");
 
             if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
                 fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
@@ -3036,19 +3057,32 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
 
             // w = w + BA*s
             ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
+            offload_func(BA);
+            ggml_set_name(BA, "BA");
 
             if (scaling != 1.0f) {
                 ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
+                ggml_set_name(scale_tensor, "scale_tensor");
+
                 BA = ggml_scale_inplace(lora_ctx, BA, scale_tensor);
+                offload_func(BA);
+                ggml_set_name(BA, "BA_scaled");
             }
 
             ggml_tensor * r;
             if (base_t == dest_t) {
                 r = ggml_add_inplace(lora_ctx, dest_t, BA);
+                offload_func_force_inplace(r);
+                ggml_set_name(r, "r_add_inplace");
             }
             else {
                 r = ggml_add(lora_ctx, base_t, BA);
+                offload_func(r);
+                ggml_set_name(r, "r_add");
+
                 r = ggml_cpy(lora_ctx, r, dest_t);
+                offload_func(r);
+                ggml_set_name(r, "r_cpy");
             }
 
             struct ggml_cgraph gf = ggml_build_forward(r);