]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: mul_mat_id always on GPU for batches >= 32 (#4553)
authorJohannes Gäßler <redacted>
Thu, 21 Dec 2023 17:42:59 +0000 (18:42 +0100)
committerGitHub <redacted>
Thu, 21 Dec 2023 17:42:59 +0000 (18:42 +0100)
ggml-cuda.cu

index 1ca071d90b935a113d7462cbac1aa019e5b4f03e..036668bfddefc05ea16970a2bfee3d1e594a61f3 100644 (file)
@@ -8773,8 +8773,6 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     // TODO: mmq/mmv support
 #endif
 
-    GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
-
     const int64_t nb11 = src1->nb[1];
     const int64_t nb1  =  dst->nb[1];
 
@@ -8803,13 +8801,21 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row = *dst;
 
+    src1_row.backend = GGML_BACKEND_GPU;
+    dst_row.backend  = GGML_BACKEND_GPU;
+
     src1_row.extra = &src1_row_extra;
     dst_row.extra = &dst_row_extra;
 
-    char * src1_original = (char *) src1_extra->data_device[g_main_device];
-    char * dst_original  = (char *)  dst_extra->data_device[g_main_device];
+    char * src1_original = src1->backend == GGML_BACKEND_CPU ?
+        (char *) src1->data : (char *) src1_extra->data_device[g_main_device];
+    char * dst_original  =  dst->backend == GGML_BACKEND_CPU ?
+        (char *)  dst->data : (char *)  dst_extra->data_device[g_main_device];
 
     if (src1->ne[1] == 1) {
+        GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
+        GGML_ASSERT(dst->backend  == GGML_BACKEND_GPU);
+
         for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
             //int32_t row_id;
             //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
@@ -8837,6 +8843,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
         src1_row_extra.data_device[g_main_device] = src1_contiguous;
         dst_row_extra.data_device[g_main_device]  =  dst_contiguous;
 
+        const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_CPU ?
+            cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
+        const cudaMemcpyKind dst_kind  =  dst->backend == GGML_BACKEND_CPU ?
+            cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
+
         for (int32_t row_id = 0; row_id < n_as; ++row_id) {
             const struct ggml_tensor * src0_row = dst->src[row_id + 2];
 
@@ -8851,7 +8862,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
                 GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
                 CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
-                                        nb11, cudaMemcpyDeviceToDevice, stream));
+                                        nb11, src1_kind, stream));
                 num_src1_rows++;
             }
 
@@ -8883,7 +8894,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
                 GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
                 CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
-                                        nb1, cudaMemcpyDeviceToDevice, stream));
+                                        nb1, dst_kind, stream));
                 num_src1_rows++;
             }
         }
@@ -8891,6 +8902,10 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
         ggml_cuda_pool_free(src1_contiguous, as_src1);
         ggml_cuda_pool_free(dst_contiguous,  as_dst);
     }
+
+    if (dst->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaStreamSynchronize(stream));
+    }
 }
 
 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9289,7 +9304,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
         || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
 
-    if (!any_on_device && tensor->op != GGML_OP_MUL_MAT) {
+    if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) {
         return false;
     }