]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : add batched cuBLAS GEMM for faster attention (#3749)
authorGeorgi Gerganov <redacted>
Tue, 24 Oct 2023 13:48:37 +0000 (16:48 +0300)
committerGitHub <redacted>
Tue, 24 Oct 2023 13:48:37 +0000 (16:48 +0300)
* cmake : add helper for faster CUDA builds

* batched : add NGL arg

* ggml : skip nops in compute_forward

* cuda : minor indentation

* cuda : batched cuBLAS GEMMs for src0 F16 and src1 F32 (attention ops)

* Apply suggestions from code review

These changes plus:

```c++
#define cublasGemmBatchedEx hipblasGemmBatchedEx
```

are needed to compile with ROCM. I haven't done performance testing, but it seems to work.

I couldn't figure out how to propose a change for lines outside what the pull changed, also this is the first time trying to create a multi-part review so please forgive me if I mess something up.

* cuda : add ROCm / hipBLAS cublasGemmBatchedEx define

* cuda : add cublasGemmStridedBatchedEx for non-broadcasted cases

* cuda : reduce mallocs in cublasGemmBatchedEx branch

* cuda : add TODO for calling cublas from kernel + using mem pool

---------

Co-authored-by: Kerfuffle <redacted>
CMakeLists.txt
examples/batched/batched.cpp
ggml-cuda.cu
ggml.c

index 6af42a6c266c830cc2f3dd2f52c266e80323bccc..202f260491d3938338a2b575f0b1e0dc7261067c 100644 (file)
@@ -331,6 +331,7 @@ if (LLAMA_CUBLAS)
             set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
         else()
             set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
+            #set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work
         endif()
     endif()
     message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
index 75856a81fe9b1b9e50448328dfda5ac41c5863c1..22a4265df77c09b2c393677711c061e8feaa1f18 100644 (file)
@@ -11,7 +11,7 @@ int main(int argc, char ** argv) {
     gpt_params params;
 
     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN]\n" , argv[0]);
+        printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN] [NGL]\n" , argv[0]);
         return 1 ;
     }
 
@@ -21,6 +21,9 @@ int main(int argc, char ** argv) {
     // total length of the sequences including the prompt
     int n_len = 32;
 
+    // number of layers to offload to the GPU
+    int n_gpu_layers = 0;
+
     if (argc >= 2) {
         params.model = argv[1];
     }
@@ -37,6 +40,10 @@ int main(int argc, char ** argv) {
         n_len = std::atoi(argv[4]);
     }
 
+    if (argc >= 6) {
+        n_gpu_layers = std::atoi(argv[5]);
+    }
+
     if (params.prompt.empty()) {
         params.prompt = "Hello my name is";
     }
@@ -49,7 +56,7 @@ int main(int argc, char ** argv) {
 
     llama_model_params model_params = llama_model_default_params();
 
-    // model_params.n_gpu_layers = 99; // offload all layers to the GPU
+    model_params.n_gpu_layers = n_gpu_layers;
 
     llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
 
index 654d3632fc179a3608b47040e78b9886f403172b..db053e3b8a9d81d4e7f05e12ec817c762d43d6d9 100644 (file)
@@ -29,6 +29,8 @@
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
 #define cublasCreate hipblasCreate
 #define cublasGemmEx hipblasGemmEx
+#define cublasGemmBatchedEx hipblasGemmBatchedEx
+#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
 #define cublasHandle_t hipblasHandle_t
 #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
 #define cublasSetStream hipblasSetStream
@@ -4326,13 +4328,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
 
     const half * x = (const half *) vx;
 
-    const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
-    const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+    const int row_x     = blockDim.y*blockIdx.y + threadIdx.y;
+    const int channel   = blockDim.z*blockIdx.z + threadIdx.z;
     const int channel_x = channel / channel_x_divisor;
 
-    const int nrows_y = ncols_x;
+    const int nrows_y   = ncols_x;
     const int nrows_dst = nrows_x;
-    const int row_dst = row_x;
+    const int row_dst   = row_x;
 
     const int idst = channel*nrows_dst + row_dst;
 
@@ -4345,13 +4347,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
             break;
         }
 
-        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const float xi = __half2float(x[ix]);
-
         const int row_y = col_x;
 
+        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
         const int iy = channel*nrows_y + row_y;
 
+        const float xi = __half2float(x[ix]);
+
         tmp += xi * y[iy];
     }
 
@@ -7013,7 +7015,8 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens
 }
 
 static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
-    GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -7023,11 +7026,11 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
 
-    const int64_t ne12 = src1->ne[2];
-
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];
 
+    const int64_t ne12 = src1->ne[2];
+
     CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
@@ -7046,6 +7049,159 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 
+static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
+    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
+    const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
+    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
+
+    const int64_t ne1 = ggml_nelements(src1);
+    const int64_t ne  = ggml_nelements(dst);
+
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
+
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    void * src0_ddq = src0_extra->data_device[g_main_device];
+    half * src0_as_f16 = (half *) src0_ddq;
+
+    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+
+    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+    // convert src1 to fp16
+    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+    GGML_ASSERT(to_fp16_cuda != nullptr);
+
+    size_t src1_as = 0;
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
+    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
+
+    size_t dst_as = 0;
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    const half alpha_f16 = 1.0f;
+    const half beta_f16  = 0.0f;
+
+#if 0
+    // use cublasGemmEx
+    {
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                CUBLAS_CHECK(
+                        cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half),
+                                        (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
+                            &beta_f16,  (      char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
+                            CUBLAS_COMPUTE_16F,
+                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+            }
+        }
+    }
+#else
+    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+        // use cublasGemmStridedBatchedEx
+        CUBLAS_CHECK(
+        cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
+                            (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+                &beta_f16,  (      char *)     dst_f16, CUDA_R_16F, ne01,                dst->nb[2]/sizeof(float), // strideC
+                ne12*ne13,
+                CUBLAS_COMPUTE_16F,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+    } else {
+        // use cublasGemmBatchedEx
+        // TODO: https://github.com/ggerganov/llama.cpp/pull/3749#discussion_r1369997000
+        const int ne23 = ne12*ne13;
+
+        // TODO: avoid this alloc
+        void ** ptrs = (void **) malloc(3*ne23*sizeof(void *));
+
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3];
+                ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
+                ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
+            }
+        }
+
+        // allocate device memory for pointers
+        void ** ptrs_as = nullptr;
+        CUDA_CHECK(cudaMalloc(&ptrs_as, 3*ne23*sizeof(void *)));
+
+        // TODO: this does not work for some reason -- not sure why?
+        //size_t ptrs_s = 0;
+        //ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
+
+        // copy pointers to device
+        CUDA_CHECK(cudaMemcpy(ptrs_as, ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice));
+
+        free(ptrs);
+
+        CUBLAS_CHECK(
+        cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                &alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                            (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+                &beta_f16,  (      void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
+                ne23,
+                CUBLAS_COMPUTE_16F,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+        // free device memory for pointers
+        CUDA_CHECK(cudaFree(ptrs_as));
+        //ggml_cuda_pool_free(ptrs_as, ptrs_s);
+    }
+#endif
+
+    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+
+    ggml_cuda_pool_free(src1_as_f16, src1_as);
+    ggml_cuda_pool_free(dst_f16, dst_as);
+}
+
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
@@ -7058,10 +7214,22 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         }
     }
 
+    // debug helpers
+    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
+    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
+    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
+    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+
     if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // KQ
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
+    } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+        // KQV
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
+    } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
diff --git a/ggml.c b/ggml.c
index 49f3b7aba31f5651719b4da4dddf43b1ab916efd..17f0ce4877592313da8169986cc658657cc5168c 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -16602,6 +16602,10 @@ static void ggml_compute_forward_cross_entropy_loss_back(
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
     GGML_ASSERT(params);
 
+    if (tensor->op == GGML_OP_NONE) {
+        return;
+    }
+
 #ifdef GGML_USE_CUBLAS
     bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
     if (skip_cpu) {