]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add ReLU and SQR CUDA ops to (partially) fix Persimmon offloading (#4041)
authorKerfuffle <redacted>
Mon, 13 Nov 2023 08:58:15 +0000 (01:58 -0700)
committerGitHub <redacted>
Mon, 13 Nov 2023 08:58:15 +0000 (01:58 -0700)
* Add ReLU and SQR CUDA ops to fix Persimmon offloading

* Persimmon loader: More helpful error on CUDA/ROCM when offloading too many layers

ggml-cuda.cu
llama.cpp

index f87f18802c8f870f5570780dd412df5ea7f7a1aa..8d03ba6641981fe8b4b927302c71335083fba0f4 100644 (file)
@@ -433,6 +433,8 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
@@ -553,6 +555,24 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void relu_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fmaxf(x[i], 0);
+}
+
+static __global__ void sqr_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * x[i];
+}
+
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
@@ -4759,6 +4779,16 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+    relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
+    sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
@@ -6128,6 +6158,34 @@ inline void ggml_cuda_op_silu(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_relu(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_cuda_op_sqr(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sqr_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_norm(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7160,6 +7218,14 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
 }
 
+static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
+}
+
+static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
+}
+
 static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
 }
@@ -7891,6 +7957,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
                 case GGML_UNARY_OP_SILU:
                     func = ggml_cuda_silu;
                     break;
+                case GGML_UNARY_OP_RELU:
+                    func = ggml_cuda_relu;
+                    break;
                 default:
                     return false;
             } break;
@@ -7909,6 +7978,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_SCALE:
             func = ggml_cuda_scale;
             break;
+        case GGML_OP_SQR:
+            func = ggml_cuda_sqr;
+            break;
         case GGML_OP_CLAMP:
             if (!any_on_device) {
                 return false;
index d682d2864d283677df636b17e05c118ddfe06f43..a5f3876cc19e0c6e471d4d2d329fb471e8a99a8a 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2877,6 +2877,13 @@ static void llm_load_tensors(
                         ggml_backend_type backend_output;
 
                         if (n_gpu_layers > int(n_layer)) {
+#ifdef GGML_USE_CUBLAS
+                            if (n_gpu_layers > int(n_layer + 1)) {
+                                LLAMA_LOG_ERROR("%s: CUDA backend missing Persimmon CUDA ops, can offload at most %ld layers. See: https://github.com/ggerganov/llama.cpp/issues/4038\n",
+                                    __func__, n_layer + 1);
+                                throw std::runtime_error("Persimmon CUDA offload failed");
+                            }
+#endif
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32