]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuBLAS: non-contiguous tensor support (#1215)
authorHenri Vasserman <redacted>
Fri, 28 Apr 2023 23:31:56 +0000 (02:31 +0300)
committerGitHub <redacted>
Fri, 28 Apr 2023 23:31:56 +0000 (01:31 +0200)
* Cuda: non-contiguous tensor support

* remove extra stuff

* rename

* fix error

* more fixes, now OpenBLAS and CLBlast build too

* now then?

ggml-cuda.cu
ggml-cuda.h
ggml.c

index d619f5da47b6349867d003b7647611d04a9982cc..eb244f409aafdc8f9916cb838fadda89b0e10e50 100644 (file)
@@ -302,3 +302,31 @@ void ggml_init_cublas(void) {
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
     }
 }
+
+cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
+    const uint64_t ne0 = src->ne[0];
+    const uint64_t ne1 = src->ne[1];
+    const uint64_t nb0 = src->nb[0];
+    const uint64_t nb1 = src->nb[1];
+    const uint64_t nb2 = src->nb[2];
+    const uint64_t nb3 = src->nb[3];
+    const enum ggml_type type = src->type;
+    const size_t ts = ggml_type_size(type);
+    const size_t bs = ggml_blck_size(type);
+
+    const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
+    if (nb0 == ts && nb1 == ts*ne0/bs) {
+        return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
+    } else if (nb0 == ts) {
+        return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
+    } else {
+        for (uint64_t i1 = 0; i1 < ne1; i1++) {
+            const void * rx = (const void *) ((const char *) x + i1*nb1);
+            void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
+            // pretend the row is a matrix with cols=1
+            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
+            if (r != cudaSuccess) return r;
+        }
+        return cudaSuccess;
+    }
+}
index b105ed0c2fa2975fb3e89b6a73957d61c8c7484d..1fd67ebeb71cc8ee01ae33c527ca2c8f9c0db768 100644 (file)
@@ -1,5 +1,6 @@
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
+#include "ggml.h"
 
 #ifdef  __cplusplus
 extern "C" {
@@ -38,6 +39,8 @@ void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t st
 void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 
+cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream);
+
 #ifdef  __cplusplus
 }
 #endif
diff --git a/ggml.c b/ggml.c
index 0c6eb7482a045df11856fafad1dc58dafba203db..4ec637ee1e0821eab7e5e3df2d951402557092dd 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -7930,8 +7930,12 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     const int64_t ne1 = dst->ne[1];
 
     // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) &&
-        ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
+    if (
+#if !defined(GGML_USE_CUBLAS)
+        ggml_is_contiguous(src0) &&
+        ggml_is_contiguous(src1) &&
+#endif
+        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
 
         /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
         return true;
@@ -8041,15 +8045,16 @@ static void ggml_compute_forward_mul_mat_f32(
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+#if !defined(GGML_USE_CUBLAS)
                 const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
-
+#endif
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
 #if defined(GGML_USE_CUBLAS)
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
-                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
 
                 // compute
                 CUBLAS_CHECK(
@@ -8269,13 +8274,12 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 #endif
 
 #if defined(GGML_USE_CUBLAS)
-                const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
                 const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
                 // compute
@@ -8539,9 +8543,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
 #if defined(GGML_USE_CUBLAS)
                 // copy and dequantize on device
-                CUDA_CHECK(
-                    cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
-                        GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream));
 
                 dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
                 CUDA_CHECK(cudaGetLastError());
@@ -8561,7 +8563,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
 #if defined(GGML_USE_CUBLAS)
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
 
                 // compute
                 CUBLAS_CHECK(