]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
RWKV v6: RWKV_WKV op CUDA implementation (#9454)
authorMolly Sophia <redacted>
Sun, 22 Sep 2024 02:29:12 +0000 (10:29 +0800)
committerGitHub <redacted>
Sun, 22 Sep 2024 02:29:12 +0000 (04:29 +0200)
* ggml: CUDA unary op EXP

Signed-off-by: Molly Sophia <redacted>
* ggml: rwkv_wkv op CUDA impl

Signed-off-by: Molly Sophia <redacted>
---------

Signed-off-by: Molly Sophia <redacted>
ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/rwkv-wkv.cu [new file with mode: 0644]
ggml/src/ggml-cuda/rwkv-wkv.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/unary.cu
ggml/src/ggml-cuda/unary.cuh
tests/test-backend-ops.cpp

index 895ba479483e71baf6f806270a7068ae1fc99982..f940511980e59fe23ec05b5eb070fdb7133becb9 100644 (file)
@@ -34,6 +34,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
+#include "ggml-cuda/rwkv-wkv.cuh"
 
 #include <algorithm>
 #include <array>
@@ -2243,6 +2244,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_HARDSWISH:
                     ggml_cuda_op_hardswish(ctx, dst);
                     break;
+                case GGML_UNARY_OP_EXP:
+                    ggml_cuda_op_exp(ctx, dst);
+                    break;
                 default:
                     return false;
             }
@@ -2345,6 +2349,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
+        case GGML_OP_RWKV_WKV:
+            ggml_cuda_op_rwkv_wkv(ctx, dst);
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
@@ -2806,6 +2812,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_TANH:
+                case GGML_UNARY_OP_EXP:
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
@@ -2967,6 +2974,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_RWKV_WKV:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
diff --git a/ggml/src/ggml-cuda/rwkv-wkv.cu b/ggml/src/ggml-cuda/rwkv-wkv.cu
new file mode 100644 (file)
index 0000000..098e92d
--- /dev/null
@@ -0,0 +1,89 @@
+#include "common.cuh"
+#include "rwkv-wkv.cuh"
+
+static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = CUDA_WKV_BLOCK_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    __syncthreads();
+    _tf[tid] = tf[head_i * head_size + tid];
+    __syncthreads();
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4& k = (float4&)(_k[j]);
+            const float4& r = (float4&)(_r[j]);
+            const float4& tf = (float4&)(_tf[j]);
+            const float4& td = (float4&)(_td[j]);
+            float4& s = (float4&)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            y += r.x * (tf.x * kv.x + s.x);
+            y += r.y * (tf.y * kv.y + s.y);
+            y += r.z * (tf.z * kv.z + s.z);
+            y += r.w * (tf.w * kv.w + s.w);
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+        }
+        dst[t] = y;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d  = (const float *)dst->src[0]->data;
+    const float * v_d  = (const float *)dst->src[1]->data;
+    const float * r_d  = (const float *)dst->src[2]->data;
+    const float * tf_d = (const float *)dst->src[3]->data;
+    const float * td_d = (const float *)dst->src[4]->data;
+    const float * s_d  = (const float *)dst->src[5]->data;
+
+    const int64_t B = dst->src[5]->ne[1];
+    const int64_t T = dst->src[0]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[2];
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
+
+    rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+}
diff --git a/ggml/src/ggml-cuda/rwkv-wkv.cuh b/ggml/src/ggml-cuda/rwkv-wkv.cuh
new file mode 100644 (file)
index 0000000..1379524
--- /dev/null
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_WKV_BLOCK_SIZE 64
+
+void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 163b5a8ffec6b72f1ffac66a6b4c31f94e139808..81fc92202f25ae45a5b21669e1cb8835e33e787d 100644 (file)
@@ -95,6 +95,15 @@ static __global__ void hardswish_f32(const float * x, float * dst, const int k)
     dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
 }
 
+static __global__ void exp_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = expf(x[i]);
+}
+
 static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
     const int i  = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
@@ -189,6 +198,11 @@ static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaSt
     hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
+    exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
     leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
@@ -354,6 +368,20 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
 
+void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
 void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
index fe519f6a232dfa053bd704d6f3f7c0942eda5510..c91936728bab16645732822b8a5839b4667f6b76 100644 (file)
@@ -8,6 +8,7 @@
 #define CUDA_RELU_BLOCK_SIZE 256
 #define CUDA_SIGMOID_BLOCK_SIZE 256
 #define CUDA_HARDSIGMOID_BLOCK_SIZE 256
+#define CUDA_EXP_BLOCK_SIZE 256
 #define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_SQRT_BLOCK_SIZE 256
@@ -32,6 +33,8 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 889a199448a6d1ce77cdd3f58fd7e376a6a58e36..efa88688cde75fab5c631d5eebec1e0beb4037a4 100644 (file)
@@ -1543,6 +1543,36 @@ struct test_ssm_scan : public test_case {
     }
 };
 
+// GGML_OP_RWKV_WKV
+struct test_rwkv_wkv : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_size;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+    }
+
+    test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_tokens = n_seq_tokens * n_seqs;
+        ggml_tensor * r   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+        ggml_tensor * k   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
+        ggml_tensor * v   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+        ggml_tensor * tf  = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
+        ggml_tensor * td  = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
+        ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
+        return out;
+    }
+};
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -3337,6 +3367,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
 
+    test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
+    test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
+    test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
+    test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
+
 #if 1
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {