]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Support dup & cont ops on CUDA (#2242)
authorJiahao Li <redacted>
Mon, 17 Jul 2023 17:39:29 +0000 (01:39 +0800)
committerGitHub <redacted>
Mon, 17 Jul 2023 17:39:29 +0000 (20:39 +0300)
ggml-cuda.cu

index 0646fa7b2582dcf06170b5fe2c71a71daed0e538..d3054a7fac71d344918303480240bcd24d4d2b9b 100644 (file)
@@ -3537,6 +3537,11 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     (void) dst;
 }
 
+void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_cpy(src0, dst, nullptr);
+    (void) src1;
+}
+
 void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true, true);
@@ -3670,7 +3675,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     // recursively assign CUDA buffers until a compute tensor is found
     if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src[0]->op;
-        if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
+        if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
             ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
         }
     }
@@ -3776,6 +3781,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
 
     switch (tensor->op) {
+        case GGML_OP_DUP:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_dup;
+            break;
         case GGML_OP_ADD:
             if (!any_on_device) {
                 return false;
@@ -3830,6 +3841,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_cpy;
             break;
+        case GGML_OP_CONT:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_dup;
+            break;
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE: