]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : add RoPE kernel for mode == 2 (NeoX) (#2760)
authorGeorgi Gerganov <redacted>
Fri, 25 Aug 2023 08:55:59 +0000 (11:55 +0300)
committerGitHub <redacted>
Fri, 25 Aug 2023 08:55:59 +0000 (11:55 +0300)
* cuda : add RoPE kernel for mode == 2 (NeoX)

* falcon : do not offload the embeddings layer

ggml-cuda.cu
llama.cpp

index 868b7a7b905a24ea7ea27a856e3834ee6a6c74a8..3bd1caf23f86d674e6e55675164733774583be33 100644 (file)
@@ -3907,28 +3907,27 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-// TODO: this implementation is wrong!
-//static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
-//                                const float p_delta, const int p_delta_rows, const float theta_scale) {
-//    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
-//
-//    if (col >= ncols) {
-//        return;
-//    }
-//
-//    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-//    const int i = row*ncols + col/2;
-//
-//    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
-//    const float sin_theta = sinf(theta);
-//    const float cos_theta = cosf(theta);
-//
-//    const float x0 = x[i + 0];
-//    const float x1 = x[i + ncols/2];
-//
-//    dst[i + 0]       = x0*cos_theta - x1*sin_theta;
-//    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
-//}
+static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
+                                const float p_delta, const int p_delta_rows, const float theta_scale) {
+    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int i = row*ncols + col/2;
+
+    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
+    const float sin_theta = sinf(theta);
+    const float cos_theta = cosf(theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + ncols/2];
+
+    dst[i + 0]       = x0*cos_theta - x1*sin_theta;
+    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
+}
 
 static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
@@ -4799,13 +4798,21 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 
 static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                           const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
-    GGML_ASSERT(nrows % 2 == 0);
+    GGML_ASSERT(nrows % 2 == 0); // GG: is this assert really needed? I don't see why
     const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }
 
+static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
+                          const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
+    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
+    rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
+}
+
 static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(nrows % 4 == 0);
     const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
@@ -5548,8 +5555,9 @@ inline void ggml_cuda_op_rope(
         const float block_p = max(p - (n_ctx - 2.f), 0.f);
         rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
     } else if (is_neox) {
-        GGML_ASSERT(false && "RoPE NeoX not implemented yet");
-#pragma message("TODO: implement RoPE NeoX for CUDA")
+        GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
+        const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
+        rope_neox_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
     } else {
         const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
index 67319396e104b6f1cf71e2da3cbd24746581827d..52ba31d79bc1ea2686010a2bf88c433d1d72237e 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1958,6 +1958,14 @@ static void llm_load_tensors(
                         model.output_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd},          backend_norm);
                         model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd},          backend_norm);
                         model.output        = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
+
+                        if (backend_norm == GGML_BACKEND_GPU) {
+                            vram_weights += ggml_nbytes(model.output_norm);
+                            vram_weights += ggml_nbytes(model.output_norm_b);
+                        }
+                        if (backend_output == GGML_BACKEND_GPU_SPLIT) {
+                            vram_weights += ggml_nbytes(model.output);
+                        }
                     }
 
                     const uint32_t n_ff = hparams.n_ff;
@@ -1967,7 +1975,7 @@ static void llm_load_tensors(
                     model.layers.resize(n_layer);
 
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
+                        const ggml_backend backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
                         const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
 
                         auto & layer = model.layers[i];
@@ -1978,6 +1986,11 @@ static void llm_load_tensors(
                         if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) {
                             layer.attn_norm_2   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, backend);
                             layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, backend);
+
+                            if (backend == GGML_BACKEND_GPU) {
+                                vram_weights += ggml_nbytes(layer.attn_norm_2);
+                                vram_weights += ggml_nbytes(layer.attn_norm_2_b);
+                            }
                         }
 
                         layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
@@ -1985,6 +1998,13 @@ static void llm_load_tensors(
 
                         layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
                         layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
+
+                        if (backend == GGML_BACKEND_GPU) {
+                            vram_weights +=
+                                ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) +
+                                ggml_nbytes(layer.wqkv)      + ggml_nbytes(layer.wo)          +
+                                ggml_nbytes(layer.w2)        + ggml_nbytes(layer.w3);
+                        }
                     }
                 } break;
             default: