]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
llama : support RWKV v6 models (llama/8980)
authorMolly Sophia <redacted>
Sun, 1 Sep 2024 14:38:17 +0000 (22:38 +0800)
committerGeorgi Gerganov <redacted>
Sun, 8 Sep 2024 11:43:07 +0000 (14:43 +0300)
* convert_hf_to_gguf: Add support for RWKV v6

Signed-off-by: Molly Sophia <redacted>
* Add RWKV tokenization

* Fix build

Signed-off-by: Molly Sophia <redacted>
* Do not use special tokens when matching in RWKV tokenizer

* Fix model loading

* Add (broken) placeholder graph builder for RWKV

* Add workaround for kv cache

* Add logits conversion to rwkv5

* Add rwkv5 layer norms

* Add time mix KVRG & correct merge mistake

* Add remaining time mix parameters

* Add time mix output loading

* Add placeholder llm_build_time_mix

* Fix build

Signed-off-by: Molly Sophia <redacted>
* Load more tensors for rwkv v6

Signed-off-by: Molly Sophia <redacted>
* Fix rwkv tokenizer

Signed-off-by: Molly Sophia <redacted>
* ggml: Add unary operator Exp

Signed-off-by: Molly Sophia <redacted>
* RWKV v6 graph building

Signed-off-by: Molly Sophia <redacted>
* Add ``rescale_every_n_layers`` parameter

Signed-off-by: Molly Sophia <redacted>
* Add ``wkv.head_size`` key for RWKV

so it doesn't reuse Mamba ssm parameters

Signed-off-by: Molly Sophia <redacted>
* Fix offloading layers to CUDA

Signed-off-by: Molly Sophia <redacted>
* Fix parallel inferencing for RWKV

Signed-off-by: Molly Sophia <redacted>
* Remove trailing whitespaces

Signed-off-by: Molly Sophia <redacted>
* build_rwkv: Avoid using inplace operations

Signed-off-by: Molly Sophia <redacted>
* convert_hf_to_gguf: rwkv: Avoid using ``eval``

Signed-off-by: Molly Sophia <redacted>
* convert_hf_to_gguf: rwkv tokenizer: Don't escape sequences manually

Signed-off-by: Molly Sophia <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: compilade <redacted>
* ggml: Add backward computation for unary op ``exp``

Signed-off-by: Molly Sophia <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: compilade <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: compilade <redacted>
* Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV

Signed-off-by: Molly Sophia <redacted>
* build_rwkv6: Simplify graph

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Detect model.type

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Fix tensor loading for 7B/14B models

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Fix group_norm assertion failure with Metal

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Clean up

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Add quantization tensor exclusion

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Use the new advanced batch splits

Signed-off-by: Molly Sophia <redacted>
* Update src/llama.cpp

Co-authored-by: compilade <redacted>
* llama: rwkv6: Use ``ggml_norm`` instead of ``ggml_group_norm``

Co-authored-by: compilade <redacted>
* llama: rwkv6: Apply code style and misc changes

Signed-off-by: Molly Sophia <redacted>
* converter: Use class name ``Rwkv6Model``

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Make use of key ``feed_forward_length``

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Add kv ``time_mix_extra_dim`` and ``time_decay_extra_dim``

Signed-off-by: Molly Sophia <redacted>
* converter: Match ``new_name`` instead of ``name`` for float32 explicit tensors

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Keep ``time_mix_w1/w2`` as F32

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Remove unused nodes

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Apply code format changes

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Add lora for some supported tensors

Currently att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight

Signed-off-by: Molly Sophia <redacted>
* rwkv : speed-up tokenization using trie

* minor : style + indentation

* llama: rwkv6: Avoid division by zero

Co-authored-by: compilade <redacted>
* ggml: rwkv_wkv: Avoid copying the state

Signed-off-by: Molly Sophia <redacted>
---------

Signed-off-by: Molly Sophia <redacted>
Co-authored-by: Layl Bongers <redacted>
Co-authored-by: compilade <redacted>
Co-authored-by: Georgi Gerganov <redacted>
include/ggml.h
src/ggml.c

index b231f6ced23e5d6197359d0a5a7564b224892148..5021ecfd8b305c6368789eefce37cb0f673a4232 100644 (file)
@@ -514,6 +514,7 @@ extern "C" {
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
+        GGML_OP_RWKV_WKV,
 
         GGML_OP_UNARY,
 
@@ -548,6 +549,7 @@ extern "C" {
         GGML_UNARY_OP_SILU,
         GGML_UNARY_OP_HARDSWISH,
         GGML_UNARY_OP_HARDSIGMOID,
+        GGML_UNARY_OP_EXP,
 
         GGML_UNARY_OP_COUNT,
     };
@@ -1165,6 +1167,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_exp(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_exp_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
@@ -1913,6 +1923,15 @@ extern "C" {
             struct ggml_tensor  * pw,
             struct ggml_tensor  * ph);
 
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * r,
+            struct ggml_tensor  * tf,
+            struct ggml_tensor  * td,
+            struct ggml_tensor  * state);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
index 948f9ad0fd73b9e5d5dfb343e4a0fba2fb3d5242..a80edf6f7afa8b8dba574774686893dfd9a00013 100644 (file)
@@ -2422,6 +2422,7 @@ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x
 // TODO: optimize performance
 inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
 inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
 
 static const float GELU_COEF_A     = 0.044715f;
 static const float GELU_QUICK_COEF = -1.702f;
@@ -2932,6 +2933,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "WIN_UNPART",
     "GET_REL_POS",
     "ADD_REL_POS",
+    "RWKV_WKV",
 
     "UNARY",
 
@@ -2950,7 +2952,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3024,6 +3026,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_unpart(x)",
     "get_rel_pos(x)",
     "add_rel_pos(x)",
+    "rwkv_wkv(k, v, r, tf, td, s)",
 
     "unary(x)",
 
@@ -3042,7 +3045,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -3061,9 +3064,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "SILU",
     "HARDSWISH",
     "HARDSIGMOID",
+    "EXP",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
+static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -5466,6 +5470,19 @@ struct ggml_tensor * ggml_hardsigmoid(
     return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);
 }
 
+// ggml exp
+struct ggml_tensor * ggml_exp(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_EXP);
+}
+
+struct ggml_tensor * ggml_exp_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
+}
+
 // ggml_norm
 
 static struct ggml_tensor * ggml_norm_impl(
@@ -7734,6 +7751,59 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
     return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
 }
 
+// ggml_rwkv_wkv
+
+struct ggml_tensor * ggml_rwkv_wkv(
+        struct ggml_context * ctx,
+        struct ggml_tensor * k,
+        struct ggml_tensor * v,
+        struct ggml_tensor * r,
+        struct ggml_tensor * tf,
+        struct ggml_tensor * td,
+        struct ggml_tensor * state) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(r));
+    GGML_ASSERT(ggml_is_contiguous(tf));
+    GGML_ASSERT(ggml_is_contiguous(td));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[2];
+    const int64_t n_tokens = k->ne[3];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(k->ne[1] == 1);
+        GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
+        GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
+        // TODO: RWKV v4 and v5
+        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    bool is_node = false;
+
+    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
+        GGML_ABORT("fatal error"); // TODO: implement backward
+        is_node = true;
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op   = GGML_OP_RWKV_WKV;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = r;
+    result->src[3] = tf;
+    result->src[4] = td;
+    result->src[5] = state;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
@@ -12114,6 +12184,48 @@ static void ggml_compute_forward_hardsigmoid(
     }
 }
 
+static void ggml_compute_forward_exp_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_exp_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_exp(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_exp_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 
 // ggml_compute_forward_norm
 
@@ -16692,6 +16804,10 @@ static void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_hardsigmoid(params, dst);
             } break;
+        case GGML_UNARY_OP_EXP:
+            {
+                ggml_compute_forward_exp(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -16827,6 +16943,96 @@ static void ggml_compute_forward_add_rel_pos(
     }
 }
 
+// ggml_compute_forward_rwkv_wkv
+
+static void ggml_compute_forward_rwkv_wkv_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const size_t T = dst->src[1]->ne[3];
+    const size_t C = dst->ne[0];
+    const size_t H = dst->src[1]->ne[2];
+    const size_t n_seqs = dst->src[5]->ne[1];
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    memset(dst_data, 0, T * C * sizeof(float));
+
+    float * k =          (float *) dst->src[0]->data;
+    float * v =          (float *) dst->src[1]->data;
+    float * r =          (float *) dst->src[2]->data;
+    float * time_faaaa = (float *) dst->src[3]->data;
+    float * time_decay = (float *) dst->src[4]->data;
+
+    size_t t_stride = H * (C / H);
+
+    size_t h_stride = C / H;
+    size_t h_stride_2d = (C / H) * (C / H);
+
+    // basically fused operations:
+    // dst = r @ (time_faaaa * (k @ v) + state),
+    // state = time_decay * state + (k @ v),
+    // recursive through each token
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = (C / H) * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < C / H; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                // RWKV v6: different time_decay for each token.
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                for (size_t j = 0; j < C / H; j ++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                    dst_data[t_h_j_offset] += temp_val * r_val;
+                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rwkv_wkv(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -17478,6 +17684,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add_rel_pos(params, tensor);
             } break;
+        case GGML_OP_RWKV_WKV:
+            {
+                ggml_compute_forward_rwkv_wkv(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -18591,12 +18801,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                         zero_table);
                             }
                         } break;
+                    case GGML_UNARY_OP_EXP:
+                        {
+                            if (src0->grad) {
+                                src0->grad = ggml_add_or_set(ctx,
+                                        src0->grad,
+                                        ggml_mul(ctx, tensor, tensor->grad),
+                                        zero_table);
+                            }
+                        } break;
                     default:
                         GGML_ABORT("fatal error");
                 }
             } break;
         case GGML_OP_GET_REL_POS:
         case GGML_OP_ADD_REL_POS:
+        case GGML_OP_RWKV_WKV:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
@@ -19021,6 +19241,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_EXP:
                     {
                         n_tasks = 1;
                     } break;
@@ -19112,6 +19333,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
+        case GGML_OP_RWKV_WKV:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32: