llama : support RWKV v6 models (#8980)
author    Molly Sophia <redacted>
Sun, 1 Sep 2024 14:38:17 +0000 (22:38 +0800)
committer GitHub <redacted>
Sun, 1 Sep 2024 14:38:17 +0000 (17:38 +0300)
* convert_hf_to_gguf: Add support for RWKV v6

Signed-off-by: Molly Sophia <redacted>
* Add RWKV tokenization

* Fix build

Signed-off-by: Molly Sophia <redacted>
* Do not use special tokens when matching in RWKV tokenizer

* Fix model loading

* Add (broken) placeholder graph builder for RWKV

* Add workaround for kv cache

* Add logits conversion to rwkv5

* Add rwkv5 layer norms

* Add time mix KVRG & correct merge mistake

* Add remaining time mix parameters

* Add time mix output loading

* Add placeholder llm_build_time_mix

* Fix build

Signed-off-by: Molly Sophia <redacted>
* Load more tensors for rwkv v6

Signed-off-by: Molly Sophia <redacted>
* Fix rwkv tokenizer

Signed-off-by: Molly Sophia <redacted>
* ggml: Add unary operator Exp

Signed-off-by: Molly Sophia <redacted>
* RWKV v6 graph building

Signed-off-by: Molly Sophia <redacted>
* Add ``rescale_every_n_layers`` parameter

Signed-off-by: Molly Sophia <redacted>
* Add ``wkv.head_size`` key for RWKV

so it doesn't reuse Mamba ssm parameters

Signed-off-by: Molly Sophia <redacted>
* Fix offloading layers to CUDA

Signed-off-by: Molly Sophia <redacted>
* Fix parallel inferencing for RWKV

Signed-off-by: Molly Sophia <redacted>
* Remove trailing whitespaces

Signed-off-by: Molly Sophia <redacted>
* build_rwkv: Avoid using inplace operations

Signed-off-by: Molly Sophia <redacted>
* convert_hf_to_gguf: rwkv: Avoid using ``eval``

Signed-off-by: Molly Sophia <redacted>
* convert_hf_to_gguf: rwkv tokenizer: Don't escape sequences manually

Signed-off-by: Molly Sophia <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: compilade <redacted>
* ggml: Add backward computation for unary op ``exp``

Signed-off-by: Molly Sophia <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: compilade <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: compilade <redacted>
* Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV

Signed-off-by: Molly Sophia <redacted>
* build_rwkv6: Simplify graph

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Detect model.type

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Fix tensor loading for 7B/14B models

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Fix group_norm assertion failure with Metal

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Clean up

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Add quantization tensor exclusion

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Use the new advanced batch splits

Signed-off-by: Molly Sophia <redacted>
* Update src/llama.cpp

Co-authored-by: compilade <redacted>
* llama: rwkv6: Use ``ggml_norm`` instead of ``ggml_group_norm``

Co-authored-by: compilade <redacted>
* llama: rwkv6: Apply code style and misc changes

Signed-off-by: Molly Sophia <redacted>
* converter: Use class name ``Rwkv6Model``

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Make use of key ``feed_forward_length``

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Add kv ``time_mix_extra_dim`` and ``time_decay_extra_dim``

Signed-off-by: Molly Sophia <redacted>
* converter: Match ``new_name`` instead of ``name`` for float32 explicit tensors

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Keep ``time_mix_w1/w2`` as F32

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Remove unused nodes

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Apply code format changes

Signed-off-by: Molly Sophia <redacted>
* llama: rwkv6: Add lora for some supported tensors

Currently att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight

Signed-off-by: Molly Sophia <redacted>
* rwkv : speed-up tokenization using trie

* minor : style + indentation

* llama: rwkv6: Avoid division by zero

Co-authored-by: compilade <redacted>
* ggml: rwkv_wkv: Avoid copying the state

Signed-off-by: Molly Sophia <redacted>
---------

Signed-off-by: Molly Sophia <redacted>
Co-authored-by: Layl Bongers <redacted>
Co-authored-by: compilade <redacted>
Co-authored-by: Georgi Gerganov <redacted>
convert_hf_to_gguf.py
ggml/include/ggml.h
ggml/src/ggml.c
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
include/llama.h
src/llama-vocab.cpp
src/llama.cpp

index caa41aee5f30be7f63178aa4d6aa3ac0d2b0b11f..27ac34b810acd3e3fc56bb386529561e94bd70d0 100755 (executable)
@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
+import ast
 import logging
 import argparse
 import contextlib
@@ -298,9 +299,12 @@ class Model:
                             gguf.MODEL_TENSOR.POS_EMBD,
                             gguf.MODEL_TENSOR.TOKEN_TYPES,
                             gguf.MODEL_TENSOR.SSM_CONV1D,
+                            gguf.MODEL_TENSOR.TIME_MIX_FIRST,
+                            gguf.MODEL_TENSOR.TIME_MIX_W1,
+                            gguf.MODEL_TENSOR.TIME_MIX_W2,
                         )
                     )
-                    or not name.endswith(".weight")
+                    or not new_name.endswith(".weight")
                 ):
                     data_qtype = gguf.GGMLQuantizationType.F32
 
@@ -2716,6 +2720,84 @@ class StarCoder2Model(Model):
     model_arch = gguf.MODEL_ARCH.STARCODER2
 
 
+@Model.register("Rwkv6ForCausalLM")
+class Rwkv6Model(Model):
+    model_arch = gguf.MODEL_ARCH.RWKV6
+
+    def set_vocab(self):
+        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
+        vocab_size = self.hparams.get("vocab_size", 65536)
+
+        tokens: list[bytes] = ['<s>'.encode("utf-8")]
+        toktypes: list[int] = [gguf.TokenType.CONTROL]
+
+        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
+            lines = f.readlines()
+            for line in lines:
+                parts = line.split(' ')
+                assert len(parts) >= 3
+                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
+                token = token.encode("utf-8") if isinstance(token, str) else token
+                assert isinstance(token, bytes)
+                assert len(token) == token_len
+                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
+                tokens.append(token_text.encode("utf-8"))
+                toktypes.append(gguf.TokenType.NORMAL)
+        remainder = vocab_size - len(tokens)
+        assert remainder >= 0
+        for i in range(len(tokens), vocab_size):
+            tokens.append(f"[PAD{i}]".encode("utf-8"))
+            toktypes.append(gguf.TokenType.UNUSED)
+
+        self.gguf_writer.add_tokenizer_model("rwkv")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        head_size = self.hparams["head_size"]
+        hidden_size = self.hparams["hidden_size"]
+        layer_norm_eps = self.hparams["layer_norm_epsilon"]
+        rescale_every_n_layers = self.hparams["rescale_every"]
+        intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else int((hidden_size * 3.5) // 32 * 32)
+        time_mix_extra_dim = 64 if hidden_size == 4096 else 32
+        time_decay_extra_dim = 128 if hidden_size == 4096 else 64
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
+        self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
+        self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        new_name = self.map_tensor_name(name)
+
+        if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
+            new_name += ".weight"
+
+        if new_name.endswith("time_mix_w1.weight") or new_name.endswith("time_mix_decay_w1.weight") or new_name.endswith("time_mix_decay_w2.weight"):
+            data_torch = data_torch.transpose(0, 1)
+
+        if new_name.endswith("time_mix_w2.weight"):
+            data_torch = data_torch.permute(0, 2, 1)
+
+        rescale_every_n_layers = self.hparams["rescale_every"]
+        if rescale_every_n_layers > 0:
+            if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
+                data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
+
+        yield (new_name, data_torch)
+
+
 @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA
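
Aside (illustrative, not part of the diff): `Rwkv6Model.set_vocab` above parses `rwkv_vocab_v20230424.txt`, where each line carries a token id, a Python string/bytes literal, and the token's length in bytes. A minimal sketch of parsing one line, with a made-up sample line:

    import ast

    line = "33 '!' 1"  # hypothetical line: <id> <python literal> <byte length>
    parts = line.split(' ')
    token = ast.literal_eval(' '.join(parts[1:-1]))  # str or bytes literal
    token = token.encode("utf-8") if isinstance(token, str) else token
    assert len(token) == int(parts[-1])              # stored length is in bytes

`ast.literal_eval` replaces the earlier `eval` (see the commit log above), so only plain literals can appear in the vocab file.
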
index 5233a9995b62991a7da8937aaa942a86a19cb231..3fb680360c3830156819445dd262663a6a5cab9f 100644 (file)
@@ -514,6 +514,7 @@ extern "C" {
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
+        GGML_OP_RWKV_WKV,
 
         GGML_OP_UNARY,
 
@@ -548,6 +549,7 @@ extern "C" {
         GGML_UNARY_OP_SILU,
         GGML_UNARY_OP_HARDSWISH,
         GGML_UNARY_OP_HARDSIGMOID,
+        GGML_UNARY_OP_EXP,
 
         GGML_UNARY_OP_COUNT,
     };
@@ -1165,6 +1167,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_exp(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_exp_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
@@ -1913,6 +1923,15 @@ extern "C" {
             struct ggml_tensor  * pw,
             struct ggml_tensor  * ph);
 
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * r,
+            struct ggml_tensor  * tf,
+            struct ggml_tensor  * td,
+            struct ggml_tensor  * state);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
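
Aside (illustrative, not part of the diff): `ggml_exp` declared above is a plain elementwise exponential. Since the derivative of $e^{x}$ is itself, the backward rule added further down in ggml.c can reuse the forward output $y$:

    \frac{\partial L}{\partial x} \mathrel{+}= y \cdot \frac{\partial L}{\partial y}, \qquad y = e^{x}

which is exactly the `ggml_mul(ctx, tensor, tensor->grad)` added to `ggml_compute_backward`.
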
index dc6cdca0bd8f65129afb8e663473b47b10cb9697..83140d757d0ce6e4c7a91dcb6064fd37a022c6a0 100644 (file)
@@ -2422,6 +2422,7 @@ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x
 // TODO: optimize performance
 inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
 inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
 
 static const float GELU_COEF_A     = 0.044715f;
 static const float GELU_QUICK_COEF = -1.702f;
@@ -2932,6 +2933,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "WIN_UNPART",
     "GET_REL_POS",
     "ADD_REL_POS",
+    "RWKV_WKV",
 
     "UNARY",
 
@@ -2950,7 +2952,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3024,6 +3026,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_unpart(x)",
     "get_rel_pos(x)",
     "add_rel_pos(x)",
+    "rwkv_wkv(k, v, r, tf, td, s)",
 
     "unary(x)",
 
@@ -3042,7 +3045,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -3061,9 +3064,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "SILU",
     "HARDSWISH",
     "HARDSIGMOID",
+    "EXP",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
+static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -5464,6 +5468,19 @@ struct ggml_tensor * ggml_hardsigmoid(
     return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);
 }
 
+// ggml exp
+struct ggml_tensor * ggml_exp(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_EXP);
+}
+
+struct ggml_tensor * ggml_exp_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
+}
+
 // ggml_norm
 
 static struct ggml_tensor * ggml_norm_impl(
@@ -7727,6 +7744,59 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
     return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
 }
 
+// ggml_rwkv_wkv
+
+struct ggml_tensor * ggml_rwkv_wkv(
+        struct ggml_context * ctx,
+        struct ggml_tensor * k,
+        struct ggml_tensor * v,
+        struct ggml_tensor * r,
+        struct ggml_tensor * tf,
+        struct ggml_tensor * td,
+        struct ggml_tensor * state) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(r));
+    GGML_ASSERT(ggml_is_contiguous(tf));
+    GGML_ASSERT(ggml_is_contiguous(td));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[2];
+    const int64_t n_tokens = k->ne[3];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(k->ne[1] == 1);
+        GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
+        GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
+        // TODO: RWKV v4 and v5
+        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    bool is_node = false;
+
+    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
+        GGML_ABORT("fatal error"); // TODO: implement backward
+        is_node = true;
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op   = GGML_OP_RWKV_WKV;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = r;
+    result->src[3] = tf;
+    result->src[4] = td;
+    result->src[5] = state;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
@@ -12126,6 +12196,48 @@ static void ggml_compute_forward_hardsigmoid(
     }
 }
 
+static void ggml_compute_forward_exp_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_exp_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_exp(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_exp_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 
 // ggml_compute_forward_norm
 
@@ -16704,6 +16816,10 @@ static void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_hardsigmoid(params, dst);
             } break;
+        case GGML_UNARY_OP_EXP:
+            {
+                ggml_compute_forward_exp(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -16839,6 +16955,96 @@ static void ggml_compute_forward_add_rel_pos(
     }
 }
 
+// ggml_compute_forward_rwkv_wkv
+
+static void ggml_compute_forward_rwkv_wkv_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const size_t T = dst->src[1]->ne[3];
+    const size_t C = dst->ne[0];
+    const size_t H = dst->src[1]->ne[2];
+    const size_t n_seqs = dst->src[5]->ne[1];
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    memset(dst_data, 0, T * C * sizeof(float));
+
+    float * k =          (float *) dst->src[0]->data;
+    float * v =          (float *) dst->src[1]->data;
+    float * r =          (float *) dst->src[2]->data;
+    float * time_faaaa = (float *) dst->src[3]->data;
+    float * time_decay = (float *) dst->src[4]->data;
+
+    size_t t_stride = H * (C / H);
+
+    size_t h_stride = C / H;
+    size_t h_stride_2d = (C / H) * (C / H);
+
+    // basically fused operations:
+    // dst = r @ (time_faaaa * (k @ v) + state),
+    // state = time_decay * state + (k @ v),
+    // recursive through each token
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = (C / H) * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < C / H; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                // RWKV v6: different time_decay for each token.
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                for (size_t j = 0; j < C / H; j ++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                    dst_data[t_h_j_offset] += temp_val * r_val;
+                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rwkv_wkv(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -17490,6 +17696,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add_rel_pos(params, tensor);
             } break;
+        case GGML_OP_RWKV_WKV:
+            {
+                ggml_compute_forward_rwkv_wkv(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -18607,12 +18817,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                         zero_table);
                             }
                         } break;
+                    case GGML_UNARY_OP_EXP:
+                        {
+                            if (src0->grad) {
+                                src0->grad = ggml_add_or_set(ctx,
+                                        src0->grad,
+                                        ggml_mul(ctx, tensor, tensor->grad),
+                                        zero_table);
+                            }
+                        } break;
                     default:
                         GGML_ABORT("fatal error");
                 }
             } break;
         case GGML_OP_GET_REL_POS:
         case GGML_OP_ADD_REL_POS:
+        case GGML_OP_RWKV_WKV:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
@@ -19036,6 +19256,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_EXP:
                     {
                         n_tasks = 1;
                     } break;
@@ -19127,6 +19348,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
+        case GGML_OP_RWKV_WKV:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
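
Aside (illustrative, not part of the diff): a NumPy reference of the per-head recurrence that `ggml_compute_forward_rwkv_wkv_f32` fuses, assuming T tokens of a single sequence and head size S (a sketch, not the actual kernel):

    import numpy as np

    def wkv_head(k, v, r, tf, td, state):
        # k, v, r, td: (T, S) per-token vectors -- td ("time_decay") is
        # per-token in RWKV v6; tf ("time_faaaa") is (S,), shared across
        # tokens; state is the (S, S) matrix carried from token to token.
        T, S = k.shape
        out = np.zeros((T, S), dtype=np.float32)
        for t in range(T):
            kv = np.outer(k[t], v[t])                   # kv[i, j] = k_i * v_j
            out[t] = r[t] @ (tf[:, None] * kv + state)  # out_j = sum_i r_i * (tf_i * kv_ij + s_ij)
            state = td[t][:, None] * state + kv         # decay row i by td_i, then add kv
        return out, state

The C loop above produces the same values, with the batched new state laid out after the T×C output block in `dst`.
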
index b55effa9907b100106cc1fe5a88e2abbb7dd505d..a48c4fb676a46baa32dd2dec776896ea2bdc6d19 100644 (file)
@@ -94,6 +94,9 @@ class Keys:
         DECODER_START_TOKEN_ID            = "{arch}.decoder_start_token_id"
         ATTN_LOGIT_SOFTCAPPING            = "{arch}.attn_logit_softcapping"
         FINAL_LOGIT_SOFTCAPPING           = "{arch}.final_logit_softcapping"
+        RESCALE_EVERY_N_LAYERS            = "{arch}.rescale_every_n_layers"
+        TIME_MIX_EXTRA_DIM                = "{arch}.time_mix_extra_dim"
+        TIME_DECAY_EXTRA_DIM              = "{arch}.time_decay_extra_dim"
 
     class Attention:
         HEAD_COUNT        = "{arch}.attention.head_count"
@@ -132,6 +135,9 @@ class Keys:
         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
         DT_B_C_RMS     = "{arch}.ssm.dt_b_c_rms"
 
+    class WKV:
+        HEAD_SIZE = "{arch}.wkv.head_size"
+
     class Tokenizer:
         MODEL                = "tokenizer.ggml.model"
         PRE                  = "tokenizer.ggml.pre"
@@ -207,6 +213,7 @@ class MODEL_ARCH(IntEnum):
     GEMMA        = auto()
     GEMMA2       = auto()
     STARCODER2   = auto()
+    RWKV6        = auto()
     MAMBA        = auto()
     XVERSE       = auto()
     COMMAND_R    = auto()
@@ -270,6 +277,29 @@ class MODEL_TENSOR(IntEnum):
     SSM_A                = auto()
     SSM_D                = auto()
     SSM_OUT              = auto()
+    TIME_MIX_W1          = auto()
+    TIME_MIX_W2          = auto()
+    TIME_MIX_LERP_X      = auto()
+    TIME_MIX_LERP_K      = auto()
+    TIME_MIX_LERP_V      = auto()
+    TIME_MIX_LERP_R      = auto()
+    TIME_MIX_LERP_G      = auto()
+    TIME_MIX_LERP_W      = auto()
+    TIME_MIX_FIRST       = auto()
+    TIME_MIX_DECAY       = auto()
+    TIME_MIX_DECAY_W1    = auto()
+    TIME_MIX_DECAY_W2    = auto()
+    TIME_MIX_KEY         = auto()
+    TIME_MIX_VALUE       = auto()
+    TIME_MIX_RECEPTANCE  = auto()
+    TIME_MIX_GATE        = auto()
+    TIME_MIX_LN          = auto()
+    TIME_MIX_OUTPUT      = auto()
+    CHANNEL_MIX_LERP_K   = auto()
+    CHANNEL_MIX_LERP_R   = auto()
+    CHANNEL_MIX_KEY      = auto()
+    CHANNEL_MIX_RECEPTANCE = auto()
+    CHANNEL_MIX_VALUE    = auto()
     ATTN_Q_A             = auto()
     ATTN_Q_B             = auto()
     ATTN_KV_A_MQA        = auto()
@@ -337,6 +367,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.GEMMA:          "gemma",
     MODEL_ARCH.GEMMA2:         "gemma2",
     MODEL_ARCH.STARCODER2:     "starcoder2",
+    MODEL_ARCH.RWKV6:          "rwkv6",
     MODEL_ARCH.MAMBA:          "mamba",
     MODEL_ARCH.XVERSE:         "xverse",
     MODEL_ARCH.COMMAND_R:      "command-r",
@@ -355,87 +386,110 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
-    MODEL_TENSOR.TOKEN_EMBD:           "token_embd",
-    MODEL_TENSOR.TOKEN_EMBD_NORM:      "token_embd_norm",
-    MODEL_TENSOR.TOKEN_TYPES:          "token_types",
-    MODEL_TENSOR.POS_EMBD:             "position_embd",
-    MODEL_TENSOR.OUTPUT_NORM:          "output_norm",
-    MODEL_TENSOR.OUTPUT:               "output",
-    MODEL_TENSOR.ROPE_FREQS:           "rope_freqs",
-    MODEL_TENSOR.ROPE_FACTORS_LONG:    "rope_factors_long",
-    MODEL_TENSOR.ROPE_FACTORS_SHORT:   "rope_factors_short",
-    MODEL_TENSOR.ATTN_NORM:            "blk.{bid}.attn_norm",
-    MODEL_TENSOR.ATTN_NORM_2:          "blk.{bid}.attn_norm_2",
-    MODEL_TENSOR.ATTN_QKV:             "blk.{bid}.attn_qkv",
-    MODEL_TENSOR.ATTN_Q:               "blk.{bid}.attn_q",
-    MODEL_TENSOR.ATTN_K:               "blk.{bid}.attn_k",
-    MODEL_TENSOR.ATTN_V:               "blk.{bid}.attn_v",
-    MODEL_TENSOR.ATTN_OUT:             "blk.{bid}.attn_output",
-    MODEL_TENSOR.ATTN_ROT_EMBD:        "blk.{bid}.attn_rot_embd",
-    MODEL_TENSOR.ATTN_Q_NORM:          "blk.{bid}.attn_q_norm",
-    MODEL_TENSOR.ATTN_K_NORM:          "blk.{bid}.attn_k_norm",
-    MODEL_TENSOR.ATTN_OUT_NORM:        "blk.{bid}.attn_output_norm",
-    MODEL_TENSOR.ATTN_POST_NORM:       "blk.{bid}.post_attention_norm",
-    MODEL_TENSOR.FFN_GATE_INP:         "blk.{bid}.ffn_gate_inp",
-    MODEL_TENSOR.FFN_GATE_INP_SHEXP:   "blk.{bid}.ffn_gate_inp_shexp",
-    MODEL_TENSOR.FFN_NORM:             "blk.{bid}.ffn_norm",
-    MODEL_TENSOR.FFN_PRE_NORM:         "blk.{bid}.ffn_norm",
-    MODEL_TENSOR.FFN_POST_NORM:        "blk.{bid}.post_ffw_norm",
-    MODEL_TENSOR.FFN_GATE:             "blk.{bid}.ffn_gate",
-    MODEL_TENSOR.FFN_DOWN:             "blk.{bid}.ffn_down",
-    MODEL_TENSOR.FFN_UP:               "blk.{bid}.ffn_up",
-    MODEL_TENSOR.FFN_GATE_SHEXP:       "blk.{bid}.ffn_gate_shexp",
-    MODEL_TENSOR.FFN_DOWN_SHEXP:       "blk.{bid}.ffn_down_shexp",
-    MODEL_TENSOR.FFN_UP_SHEXP:         "blk.{bid}.ffn_up_shexp",
-    MODEL_TENSOR.FFN_ACT:              "blk.{bid}.ffn",
-    MODEL_TENSOR.FFN_NORM_EXP:         "blk.{bid}.ffn_norm_exps",
-    MODEL_TENSOR.FFN_GATE_EXP:         "blk.{bid}.ffn_gate_exps",
-    MODEL_TENSOR.FFN_DOWN_EXP:         "blk.{bid}.ffn_down_exps",
-    MODEL_TENSOR.FFN_UP_EXP:           "blk.{bid}.ffn_up_exps",
-    MODEL_TENSOR.LAYER_OUT_NORM:       "blk.{bid}.layer_output_norm",
-    MODEL_TENSOR.SSM_IN:               "blk.{bid}.ssm_in",
-    MODEL_TENSOR.SSM_CONV1D:           "blk.{bid}.ssm_conv1d",
-    MODEL_TENSOR.SSM_X:                "blk.{bid}.ssm_x",
-    MODEL_TENSOR.SSM_DT:               "blk.{bid}.ssm_dt",
-    MODEL_TENSOR.SSM_A:                "blk.{bid}.ssm_a",
-    MODEL_TENSOR.SSM_D:                "blk.{bid}.ssm_d",
-    MODEL_TENSOR.SSM_OUT:              "blk.{bid}.ssm_out",
-    MODEL_TENSOR.ATTN_Q_A:             "blk.{bid}.attn_q_a",
-    MODEL_TENSOR.ATTN_Q_B:             "blk.{bid}.attn_q_b",
-    MODEL_TENSOR.ATTN_KV_A_MQA:        "blk.{bid}.attn_kv_a_mqa",
-    MODEL_TENSOR.ATTN_KV_B:            "blk.{bid}.attn_kv_b",
-    MODEL_TENSOR.ATTN_Q_A_NORM:        "blk.{bid}.attn_q_a_norm",
-    MODEL_TENSOR.ATTN_KV_A_NORM:       "blk.{bid}.attn_kv_a_norm",
-    MODEL_TENSOR.ATTN_SUB_NORM:        "blk.{bid}.attn_sub_norm",
-    MODEL_TENSOR.FFN_SUB_NORM:         "blk.{bid}.ffn_sub_norm",
-    MODEL_TENSOR.DEC_ATTN_NORM:        "dec.blk.{bid}.attn_norm",
-    MODEL_TENSOR.DEC_ATTN_Q:           "dec.blk.{bid}.attn_q",
-    MODEL_TENSOR.DEC_ATTN_K:           "dec.blk.{bid}.attn_k",
-    MODEL_TENSOR.DEC_ATTN_V:           "dec.blk.{bid}.attn_v",
-    MODEL_TENSOR.DEC_ATTN_OUT:         "dec.blk.{bid}.attn_o",
-    MODEL_TENSOR.DEC_ATTN_REL_B:       "dec.blk.{bid}.attn_rel_b",
-    MODEL_TENSOR.DEC_CROSS_ATTN_NORM:  "dec.blk.{bid}.cross_attn_norm",
-    MODEL_TENSOR.DEC_CROSS_ATTN_Q:     "dec.blk.{bid}.cross_attn_q",
-    MODEL_TENSOR.DEC_CROSS_ATTN_K:     "dec.blk.{bid}.cross_attn_k",
-    MODEL_TENSOR.DEC_CROSS_ATTN_V:     "dec.blk.{bid}.cross_attn_v",
-    MODEL_TENSOR.DEC_CROSS_ATTN_OUT:   "dec.blk.{bid}.cross_attn_o",
-    MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: "dec.blk.{bid}.cross_attn_rel_b",
-    MODEL_TENSOR.DEC_FFN_NORM:         "dec.blk.{bid}.ffn_norm",
-    MODEL_TENSOR.DEC_FFN_GATE:         "dec.blk.{bid}.ffn_gate",
-    MODEL_TENSOR.DEC_FFN_DOWN:         "dec.blk.{bid}.ffn_down",
-    MODEL_TENSOR.DEC_FFN_UP:           "dec.blk.{bid}.ffn_up",
-    MODEL_TENSOR.DEC_OUTPUT_NORM:      "dec.output_norm",
-    MODEL_TENSOR.ENC_ATTN_NORM:        "enc.blk.{bid}.attn_norm",
-    MODEL_TENSOR.ENC_ATTN_Q:           "enc.blk.{bid}.attn_q",
-    MODEL_TENSOR.ENC_ATTN_K:           "enc.blk.{bid}.attn_k",
-    MODEL_TENSOR.ENC_ATTN_V:           "enc.blk.{bid}.attn_v",
-    MODEL_TENSOR.ENC_ATTN_OUT:         "enc.blk.{bid}.attn_o",
-    MODEL_TENSOR.ENC_ATTN_REL_B:       "enc.blk.{bid}.attn_rel_b",
-    MODEL_TENSOR.ENC_FFN_NORM:         "enc.blk.{bid}.ffn_norm",
-    MODEL_TENSOR.ENC_FFN_GATE:         "enc.blk.{bid}.ffn_gate",
-    MODEL_TENSOR.ENC_FFN_DOWN:         "enc.blk.{bid}.ffn_down",
-    MODEL_TENSOR.ENC_FFN_UP:           "enc.blk.{bid}.ffn_up",
-    MODEL_TENSOR.ENC_OUTPUT_NORM:      "enc.output_norm",
+    MODEL_TENSOR.TOKEN_EMBD:                "token_embd",
+    MODEL_TENSOR.TOKEN_EMBD_NORM:           "token_embd_norm",
+    MODEL_TENSOR.TOKEN_TYPES:               "token_types",
+    MODEL_TENSOR.POS_EMBD:                  "position_embd",
+    MODEL_TENSOR.OUTPUT_NORM:               "output_norm",
+    MODEL_TENSOR.OUTPUT:                    "output",
+    MODEL_TENSOR.ROPE_FREQS:                "rope_freqs",
+    MODEL_TENSOR.ROPE_FACTORS_LONG:         "rope_factors_long",
+    MODEL_TENSOR.ROPE_FACTORS_SHORT:        "rope_factors_short",
+    MODEL_TENSOR.ATTN_NORM:                 "blk.{bid}.attn_norm",
+    MODEL_TENSOR.ATTN_NORM_2:               "blk.{bid}.attn_norm_2",
+    MODEL_TENSOR.ATTN_QKV:                  "blk.{bid}.attn_qkv",
+    MODEL_TENSOR.ATTN_Q:                    "blk.{bid}.attn_q",
+    MODEL_TENSOR.ATTN_K:                    "blk.{bid}.attn_k",
+    MODEL_TENSOR.ATTN_V:                    "blk.{bid}.attn_v",
+    MODEL_TENSOR.ATTN_OUT:                  "blk.{bid}.attn_output",
+    MODEL_TENSOR.ATTN_ROT_EMBD:             "blk.{bid}.attn_rot_embd",
+    MODEL_TENSOR.ATTN_Q_NORM:               "blk.{bid}.attn_q_norm",
+    MODEL_TENSOR.ATTN_K_NORM:               "blk.{bid}.attn_k_norm",
+    MODEL_TENSOR.ATTN_OUT_NORM:             "blk.{bid}.attn_output_norm",
+    MODEL_TENSOR.ATTN_POST_NORM:            "blk.{bid}.post_attention_norm",
+    MODEL_TENSOR.FFN_GATE_INP:              "blk.{bid}.ffn_gate_inp",
+    MODEL_TENSOR.FFN_GATE_INP_SHEXP:        "blk.{bid}.ffn_gate_inp_shexp",
+    MODEL_TENSOR.FFN_NORM:                  "blk.{bid}.ffn_norm",
+    MODEL_TENSOR.FFN_PRE_NORM:              "blk.{bid}.ffn_norm",
+    MODEL_TENSOR.FFN_POST_NORM:             "blk.{bid}.post_ffw_norm",
+    MODEL_TENSOR.FFN_GATE:                  "blk.{bid}.ffn_gate",
+    MODEL_TENSOR.FFN_DOWN:                  "blk.{bid}.ffn_down",
+    MODEL_TENSOR.FFN_UP:                    "blk.{bid}.ffn_up",
+    MODEL_TENSOR.FFN_GATE_SHEXP:            "blk.{bid}.ffn_gate_shexp",
+    MODEL_TENSOR.FFN_DOWN_SHEXP:            "blk.{bid}.ffn_down_shexp",
+    MODEL_TENSOR.FFN_UP_SHEXP:              "blk.{bid}.ffn_up_shexp",
+    MODEL_TENSOR.FFN_ACT:                   "blk.{bid}.ffn",
+    MODEL_TENSOR.FFN_NORM_EXP:              "blk.{bid}.ffn_norm_exps",
+    MODEL_TENSOR.FFN_GATE_EXP:              "blk.{bid}.ffn_gate_exps",
+    MODEL_TENSOR.FFN_DOWN_EXP:              "blk.{bid}.ffn_down_exps",
+    MODEL_TENSOR.FFN_UP_EXP:                "blk.{bid}.ffn_up_exps",
+    MODEL_TENSOR.LAYER_OUT_NORM:            "blk.{bid}.layer_output_norm",
+    MODEL_TENSOR.SSM_IN:                    "blk.{bid}.ssm_in",
+    MODEL_TENSOR.SSM_CONV1D:                "blk.{bid}.ssm_conv1d",
+    MODEL_TENSOR.SSM_X:                     "blk.{bid}.ssm_x",
+    MODEL_TENSOR.SSM_DT:                    "blk.{bid}.ssm_dt",
+    MODEL_TENSOR.SSM_A:                     "blk.{bid}.ssm_a",
+    MODEL_TENSOR.SSM_D:                     "blk.{bid}.ssm_d",
+    MODEL_TENSOR.SSM_OUT:                   "blk.{bid}.ssm_out",
+    MODEL_TENSOR.TIME_MIX_W1:               "blk.{bid}.time_mix_w1",
+    MODEL_TENSOR.TIME_MIX_W2:               "blk.{bid}.time_mix_w2",
+    MODEL_TENSOR.TIME_MIX_LERP_X:           "blk.{bid}.time_mix_lerp_x",
+    MODEL_TENSOR.TIME_MIX_LERP_K:           "blk.{bid}.time_mix_lerp_k",
+    MODEL_TENSOR.TIME_MIX_LERP_V:           "blk.{bid}.time_mix_lerp_v",
+    MODEL_TENSOR.TIME_MIX_LERP_R:           "blk.{bid}.time_mix_lerp_r",
+    MODEL_TENSOR.TIME_MIX_LERP_G:           "blk.{bid}.time_mix_lerp_g",
+    MODEL_TENSOR.TIME_MIX_LERP_W:           "blk.{bid}.time_mix_lerp_w",
+    MODEL_TENSOR.TIME_MIX_FIRST:            "blk.{bid}.time_mix_first",
+    MODEL_TENSOR.TIME_MIX_DECAY:            "blk.{bid}.time_mix_decay",
+    MODEL_TENSOR.TIME_MIX_DECAY_W1:         "blk.{bid}.time_mix_decay_w1",
+    MODEL_TENSOR.TIME_MIX_DECAY_W2:         "blk.{bid}.time_mix_decay_w2",
+    MODEL_TENSOR.TIME_MIX_KEY:              "blk.{bid}.time_mix_key",
+    MODEL_TENSOR.TIME_MIX_VALUE:            "blk.{bid}.time_mix_value",
+    MODEL_TENSOR.TIME_MIX_RECEPTANCE:       "blk.{bid}.time_mix_receptance",
+    MODEL_TENSOR.TIME_MIX_GATE:             "blk.{bid}.time_mix_gate",
+    MODEL_TENSOR.TIME_MIX_LN:               "blk.{bid}.time_mix_ln",
+    MODEL_TENSOR.TIME_MIX_OUTPUT:           "blk.{bid}.time_mix_output",
+    MODEL_TENSOR.CHANNEL_MIX_LERP_K:        "blk.{bid}.channel_mix_lerp_k",
+    MODEL_TENSOR.CHANNEL_MIX_LERP_R:        "blk.{bid}.channel_mix_lerp_r",
+    MODEL_TENSOR.CHANNEL_MIX_KEY:           "blk.{bid}.channel_mix_key",
+    MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE:    "blk.{bid}.channel_mix_receptance",
+    MODEL_TENSOR.CHANNEL_MIX_VALUE:         "blk.{bid}.channel_mix_value",
+    MODEL_TENSOR.ATTN_Q_A:                  "blk.{bid}.attn_q_a",
+    MODEL_TENSOR.ATTN_Q_B:                  "blk.{bid}.attn_q_b",
+    MODEL_TENSOR.ATTN_KV_A_MQA:             "blk.{bid}.attn_kv_a_mqa",
+    MODEL_TENSOR.ATTN_KV_B:                 "blk.{bid}.attn_kv_b",
+    MODEL_TENSOR.ATTN_Q_A_NORM:             "blk.{bid}.attn_q_a_norm",
+    MODEL_TENSOR.ATTN_KV_A_NORM:            "blk.{bid}.attn_kv_a_norm",
+    MODEL_TENSOR.ATTN_SUB_NORM:             "blk.{bid}.attn_sub_norm",
+    MODEL_TENSOR.FFN_SUB_NORM:              "blk.{bid}.ffn_sub_norm",
+    MODEL_TENSOR.DEC_ATTN_NORM:             "dec.blk.{bid}.attn_norm",
+    MODEL_TENSOR.DEC_ATTN_Q:                "dec.blk.{bid}.attn_q",
+    MODEL_TENSOR.DEC_ATTN_K:                "dec.blk.{bid}.attn_k",
+    MODEL_TENSOR.DEC_ATTN_V:                "dec.blk.{bid}.attn_v",
+    MODEL_TENSOR.DEC_ATTN_OUT:              "dec.blk.{bid}.attn_o",
+    MODEL_TENSOR.DEC_ATTN_REL_B:            "dec.blk.{bid}.attn_rel_b",
+    MODEL_TENSOR.DEC_CROSS_ATTN_NORM:       "dec.blk.{bid}.cross_attn_norm",
+    MODEL_TENSOR.DEC_CROSS_ATTN_Q:          "dec.blk.{bid}.cross_attn_q",
+    MODEL_TENSOR.DEC_CROSS_ATTN_K:          "dec.blk.{bid}.cross_attn_k",
+    MODEL_TENSOR.DEC_CROSS_ATTN_V:          "dec.blk.{bid}.cross_attn_v",
+    MODEL_TENSOR.DEC_CROSS_ATTN_OUT:        "dec.blk.{bid}.cross_attn_o",
+    MODEL_TENSOR.DEC_CROSS_ATTN_REL_B:      "dec.blk.{bid}.cross_attn_rel_b",
+    MODEL_TENSOR.DEC_FFN_NORM:              "dec.blk.{bid}.ffn_norm",
+    MODEL_TENSOR.DEC_FFN_GATE:              "dec.blk.{bid}.ffn_gate",
+    MODEL_TENSOR.DEC_FFN_DOWN:              "dec.blk.{bid}.ffn_down",
+    MODEL_TENSOR.DEC_FFN_UP:                "dec.blk.{bid}.ffn_up",
+    MODEL_TENSOR.DEC_OUTPUT_NORM:           "dec.output_norm",
+    MODEL_TENSOR.ENC_ATTN_NORM:             "enc.blk.{bid}.attn_norm",
+    MODEL_TENSOR.ENC_ATTN_Q:                "enc.blk.{bid}.attn_q",
+    MODEL_TENSOR.ENC_ATTN_K:                "enc.blk.{bid}.attn_k",
+    MODEL_TENSOR.ENC_ATTN_V:                "enc.blk.{bid}.attn_v",
+    MODEL_TENSOR.ENC_ATTN_OUT:              "enc.blk.{bid}.attn_o",
+    MODEL_TENSOR.ENC_ATTN_REL_B:            "enc.blk.{bid}.attn_rel_b",
+    MODEL_TENSOR.ENC_FFN_NORM:              "enc.blk.{bid}.ffn_norm",
+    MODEL_TENSOR.ENC_FFN_GATE:              "enc.blk.{bid}.ffn_gate",
+    MODEL_TENSOR.ENC_FFN_DOWN:              "enc.blk.{bid}.ffn_down",
+    MODEL_TENSOR.ENC_FFN_UP:                "enc.blk.{bid}.ffn_up",
+    MODEL_TENSOR.ENC_OUTPUT_NORM:           "enc.output_norm",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -856,6 +910,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.RWKV6: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_NORM_2,
+        MODEL_TENSOR.TIME_MIX_W1,
+        MODEL_TENSOR.TIME_MIX_W2,
+        MODEL_TENSOR.TIME_MIX_LERP_X,
+        MODEL_TENSOR.TIME_MIX_LERP_K,
+        MODEL_TENSOR.TIME_MIX_LERP_V,
+        MODEL_TENSOR.TIME_MIX_LERP_R,
+        MODEL_TENSOR.TIME_MIX_LERP_G,
+        MODEL_TENSOR.TIME_MIX_LERP_W,
+        MODEL_TENSOR.TIME_MIX_FIRST,
+        MODEL_TENSOR.TIME_MIX_DECAY,
+        MODEL_TENSOR.TIME_MIX_DECAY_W1,
+        MODEL_TENSOR.TIME_MIX_DECAY_W2,
+        MODEL_TENSOR.TIME_MIX_KEY,
+        MODEL_TENSOR.TIME_MIX_VALUE,
+        MODEL_TENSOR.TIME_MIX_RECEPTANCE,
+        MODEL_TENSOR.TIME_MIX_GATE,
+        MODEL_TENSOR.TIME_MIX_LN,
+        MODEL_TENSOR.TIME_MIX_OUTPUT,
+        MODEL_TENSOR.CHANNEL_MIX_LERP_K,
+        MODEL_TENSOR.CHANNEL_MIX_LERP_R,
+        MODEL_TENSOR.CHANNEL_MIX_KEY,
+        MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE,
+        MODEL_TENSOR.CHANNEL_MIX_VALUE,
+    ],
     MODEL_ARCH.MAMBA: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index af3b98c679b0b66cd36d0a1ab5dafb8262936a0f..3c95c26730f7a4cc71fd8e4859946f765eef00cd 100644 (file)
@@ -670,6 +670,18 @@ class GGUFWriter:
     def add_expert_weights_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
 
+    def add_rescale_every_n_layers(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)
+
+    def add_time_mix_extra_dim(self, dim: int) -> None:
+        self.add_uint32(Keys.LLM.TIME_MIX_EXTRA_DIM.format(arch=self.arch), dim)
+
+    def add_time_decay_extra_dim(self, dim: int) -> None:
+        self.add_uint32(Keys.LLM.TIME_DECAY_EXTRA_DIM.format(arch=self.arch), dim)
+
+    def add_wkv_head_size(self, size: int) -> None:
+        self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
+
     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
 
index a4f185c0658a34b14abd6893c16e5bf8216258af..bc9a13ee5bdf5f81b8d65be69593d739c30690db 100644 (file)
@@ -27,6 +27,7 @@ class TensorNameMap:
             "embedding.word_embeddings",                 # chatglm
             "transformer.token_embeddings",              # openelm
             "shared",                                    # t5
+            "rwkv.embeddings",                           # rwkv
         ),
 
         # Token type embeddings
@@ -40,6 +41,7 @@ class TensorNameMap:
             "embeddings.LayerNorm",       # bert
             "emb_ln",                     # nomic-bert
             "transformer.norm",           # openelm
+            "rwkv.blocks.0.pre_ln",       # rwkv
         ),
 
         # Position embeddings
@@ -57,6 +59,7 @@ class TensorNameMap:
             "word_embeddings_for_head",  # persimmon
             "lm_head.linear",            # phi2
             "output_layer",              # chatglm
+            "head",                      # rwkv
         ),
 
         # Output norm
@@ -76,6 +79,7 @@ class TensorNameMap:
             "encoder.final_layernorm",                 # chatglm
             "transformer.norm",                        # openelm
             "model.norm",                              # nemotron
+            "rwkv.ln_out",                             # rwkv
         ),
 
         # Rope frequencies
@@ -108,12 +112,14 @@ class TensorNameMap:
             "transformer.blocks.{bid}.norm_attn_norm.norm_1",       # dbrx
             "encoder.layers.{bid}.input_layernorm",                 # chatglm
             "transformer.layers.{bid}.attn_norm",                   # openelm
+            "rwkv.blocks.{bid}.ln1",                                # rwkv
         ),
 
         # Attention norm 2
         MODEL_TENSOR.ATTN_NORM_2: (
-            "transformer.h.{bid}.ln_attn",  # falcon40b
+            "transformer.h.{bid}.ln_attn",                  # falcon40b
             "encoder.layer.{bid}.layer_norm_1",             # jina-v2-code
+            "rwkv.blocks.{bid}.ln2",                        # rwkv
         ),
 
         # Attention query-key-value
@@ -434,6 +440,98 @@ class TensorNameMap:
             "backbone.layers.{bid}.mixer.out_proj",
         ),
 
+        MODEL_TENSOR.TIME_MIX_W1: (
+            "rwkv.blocks.{bid}.attention.time_maa_w1",  # rwkv v6
+        ),
+
+        MODEL_TENSOR.TIME_MIX_W2: (
+            "rwkv.blocks.{bid}.attention.time_maa_w2",  # rwkv v6
+        ),
+
+        MODEL_TENSOR.TIME_MIX_LERP_X: (
+            "rwkv.blocks.{bid}.attention.time_maa_x",   # rwkv v6
+        ),
+
+        MODEL_TENSOR.TIME_MIX_LERP_K: (
+            "rwkv.blocks.{bid}.attention.time_maa_k",   # rwkv v6
+        ),
+
+        MODEL_TENSOR.TIME_MIX_LERP_V: (
+            "rwkv.blocks.{bid}.attention.time_maa_v",   # rwkv v6
+        ),
+
+        MODEL_TENSOR.TIME_MIX_LERP_R: (
+            "rwkv.blocks.{bid}.attention.time_maa_r",   # rwkv v6
+        ),
+
+        MODEL_TENSOR.TIME_MIX_LERP_G: (
+            "rwkv.blocks.{bid}.attention.time_maa_g",   # rwkv v6
+        ),
+
+        MODEL_TENSOR.TIME_MIX_LERP_W: (
+            "rwkv.blocks.{bid}.attention.time_maa_w",   # rwkv v6
+        ),
+
+        MODEL_TENSOR.TIME_MIX_FIRST: (
+            "rwkv.blocks.{bid}.attention.time_faaaa",   # rwkv v6
+        ),
+
+        MODEL_TENSOR.TIME_MIX_DECAY: (
+            "rwkv.blocks.{bid}.attention.time_decay",   # rwkv v6
+        ),
+
+        MODEL_TENSOR.TIME_MIX_DECAY_W1: (
+            "rwkv.blocks.{bid}.attention.time_decay_w1",  # rwkv v6
+        ),
+
+        MODEL_TENSOR.TIME_MIX_DECAY_W2: (
+            "rwkv.blocks.{bid}.attention.time_decay_w2",  # rwkv v6
+        ),
+
+        MODEL_TENSOR.TIME_MIX_KEY: (
+            "rwkv.blocks.{bid}.attention.key", # rwkv
+        ),
+
+        MODEL_TENSOR.TIME_MIX_VALUE: (
+            "rwkv.blocks.{bid}.attention.value", # rwkv
+        ),
+
+        MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
+            "rwkv.blocks.{bid}.attention.receptance", # rwkv
+        ),
+
+        MODEL_TENSOR.TIME_MIX_GATE: (
+            "rwkv.blocks.{bid}.attention.gate", # rwkv
+        ),
+
+        MODEL_TENSOR.TIME_MIX_LN: (
+            "rwkv.blocks.{bid}.attention.ln_x", # rwkv
+        ),
+
+        MODEL_TENSOR.TIME_MIX_OUTPUT: (
+            "rwkv.blocks.{bid}.attention.output", # rwkv
+        ),
+
+        MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
+            "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv v6
+        ),
+
+        MODEL_TENSOR.CHANNEL_MIX_LERP_R: (
+            "rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv v6
+        ),
+
+        MODEL_TENSOR.CHANNEL_MIX_KEY: (
+            "rwkv.blocks.{bid}.feed_forward.key", # rwkv
+        ),
+
+        MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: (
+            "rwkv.blocks.{bid}.feed_forward.receptance", # rwkv
+        ),
+
+        MODEL_TENSOR.CHANNEL_MIX_VALUE: (
+            "rwkv.blocks.{bid}.feed_forward.value", # rwkv
+        ),
+
         MODEL_TENSOR.ATTN_Q_A: (
             "model.layers.{bid}.self_attn.q_a_proj", # deepseek2
         ),
index f2e701602bb00defb493560b7332aab74d0f8504..bfc37e88bbb749e8a0dcac3744a4371c14a68166 100644 (file)
@@ -66,6 +66,7 @@ extern "C" {
         LLAMA_VOCAB_TYPE_BPE  = 2, // GPT-2 tokenizer based on byte-level BPE
         LLAMA_VOCAB_TYPE_WPM  = 3, // BERT tokenizer based on WordPiece
         LLAMA_VOCAB_TYPE_UGM  = 4, // T5 tokenizer based on Unigram
+        LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
     };
 
     // pre-tokenization types
index dc4138cbc1dda8411e664dfd8484835ce007f919..2c007477e8da24341539bb8df5f2751ae112b624 100644 (file)
@@ -58,17 +58,17 @@ struct naive_trie {
         auto res = children.find(c);
         if (res != children.end()) {
             return res->second.get_longest_prefix(key, len, offset + 1);
-        } else {
-            return std::make_pair(key, offset);
         }
+
+        return std::make_pair(key, offset);
     }
-    struct naive_trie * traverse(const char c) {
+    const struct naive_trie * traverse(const char c) const {
         auto res = children.find(c);
         if (res != children.end()) {
             return &res->second;
-        } else {
-            return NULL;
         }
+
+        return NULL;
     }
     std::map<char, struct naive_trie> children;
     bool has_value;
@@ -843,7 +843,7 @@ struct llm_tokenizer_ugm {
             // traverse the token matcher trie to find a matching token
             bool single_codepoint_token_found = false;
             const struct best_tokenization & current_best = tokenization_results[input_offset];
-            struct naive_trie * node  = token_matcher.traverse(normalized[prefix_offset++]);
+            const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
 
             while (prefix_offset <= input_len && node != NULL) {
                 // check if we found valid token in prefix
@@ -1097,6 +1097,111 @@ private:
     struct naive_trie token_matcher;
 };
 
+//
+// RWKV tokenizer
+//
+
+static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
+    std::vector<uint8_t> output;
+    output.reserve(escaped.size());
+
+    // Parser state
+    bool escaping = false;
+    uint8_t hex_remaining = 0;
+    uint8_t hex_acc = 0;
+
+    // Step through characters, performing parsing
+    for (const char & c : escaped) {
+        // If we're parsing a hex code, interpret the next character
+        if (hex_remaining != 0) {
+            uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0');
+            hex_acc = (hex_acc << 4) + value;
+
+            hex_remaining -= 1;
+            if (hex_remaining == 0) {
+                output.push_back(hex_acc);
+                hex_acc = 0;
+            }
+
+            continue;
+        }
+
+        // If we got an escape character, interpret it
+        if (escaping) {
+            if (c == 't') {
+                output.push_back('\t');
+            } else if (c == 'n') {
+                output.push_back('\n');
+            } else if (c == 'r') {
+                output.push_back('\r');
+            } else if (c == 'x') {
+                hex_remaining = 2;
+            } else {
+                output.push_back(c);
+            }
+
+            escaping = false;
+            continue;
+        }
+
+        if (c == '\\') {
+            escaping = true;
+            continue;
+        }
+
+        output.push_back(c);
+    }
+
+    return output;
+}
+
+struct llm_tokenizer_rwkv {
+    llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
+        // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
+        // For now, we decode the vocab here into the lookup we'll use for tokenization.
+
+        // build trie
+        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+            const auto & token = vocab.id_to_token[id];
+            const auto data = llama_unescape_rwkv_token(token.text);
+            token_matcher.insert((const char *) data.data(), data.size(), id);
+        }
+    }
+
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        uint32_t position = 0;
+
+        while (position < text.size()) {
+            const struct naive_trie * node = token_matcher.traverse(text[position]);
+            if (node == NULL) {
+                // no matching token found, add unknown token
+                output.push_back(vocab.special_unk_id);
+                position += 1;
+                continue;
+            }
+
+            // traverse the trie to find the longest matching token
+            uint32_t token_id = 0;
+            uint32_t token_length = 0;
+            while (node != NULL) {
+                if (node->has_value) {
+                    token_id = node->value;
+                    token_length = position + 1;
+                }
+                node = node->traverse(text[++position]);
+            }
+
+            // add the longest matching token
+            output.push_back(token_id);
+            position = token_length;
+        }
+    }
+
+    const llama_vocab & vocab;
+
+    struct naive_trie token_matcher;
+};
+
 //
 // (de-) tokenize
 //
@@ -1401,6 +1506,23 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
                     output.push_back(vocab.special_eos_id);
                 }
             } break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            {
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+
+                        llm_tokenizer_rwkv tokenizer(vocab);
+                        tokenizer.tokenize(raw_text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
         case LLAMA_VOCAB_TYPE_NONE:
             GGML_ABORT("fatal error");
     }
@@ -1616,6 +1738,17 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
                 }
                 break;
             }
+            case LLAMA_VOCAB_TYPE_RWKV: {
+                std::vector<uint8_t> result = llama_unescape_rwkv_token(token_text);
+
+                // If we don't have enough space, return an error
+                if (result.size() > (size_t)length) {
+                    return -(int)result.size();
+                }
+
+                memcpy(buf, result.data(), result.size());
+                return (int)result.size();
+            }
             default:
                 GGML_ABORT("fatal error");
         }
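
Aside (illustrative, not part of the diff): `llm_tokenizer_rwkv` above does greedy longest-prefix matching over a byte trie built from the unescaped vocab, falling back to `<unk>` for unmatched bytes. A self-contained sketch of the same algorithm:

    class TrieNode:
        def __init__(self):
            self.children = {}    # next byte -> TrieNode
            self.token_id = None  # set if a token ends at this node

    def trie_insert(root, token, token_id):
        node = root
        for b in token:           # token: bytes, after llama_unescape_rwkv_token
            node = node.children.setdefault(b, TrieNode())
        node.token_id = token_id

    def rwkv_tokenize(text, root, unk_id=0):
        ids, pos = [], 0
        while pos < len(text):    # text: bytes
            node, best = root, None
            for i in range(pos, len(text)):
                node = node.children.get(text[i])
                if node is None:
                    break
                if node.token_id is not None:
                    best = (node.token_id, i + 1)  # remember longest match so far
            if best is None:
                ids.append(unk_id)                 # no match: emit <unk>, skip one byte
                pos += 1
            else:
                ids.append(best[0])
                pos = best[1]
        return ids
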
index 2274296b45406366bbef2a686aeffeb3dd3de8ba..2113c72f3c90b95abb8284e51deeb3d7b255492f 100644 (file)
@@ -212,6 +212,7 @@ enum llm_arch {
     LLM_ARCH_JAIS,
     LLM_ARCH_NEMOTRON,
     LLM_ARCH_EXAONE,
+    LLM_ARCH_RWKV6,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -259,6 +260,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_JAIS,            "jais"         },
     { LLM_ARCH_NEMOTRON,        "nemotron"     },
     { LLM_ARCH_EXAONE,          "exaone"       },
+    { LLM_ARCH_RWKV6,           "rwkv6"        },
     { LLM_ARCH_UNKNOWN,         "(unknown)"    },
 };
 
@@ -295,6 +297,9 @@ enum llm_kv {
     LLM_KV_DECODER_START_TOKEN_ID,
     LLM_KV_ATTN_LOGIT_SOFTCAPPING,
     LLM_KV_FINAL_LOGIT_SOFTCAPPING,
+    LLM_KV_RESCALE_EVERY_N_LAYERS,
+    LLM_KV_TIME_MIX_EXTRA_DIM,
+    LLM_KV_TIME_DECAY_EXTRA_DIM,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -330,6 +335,8 @@ enum llm_kv {
     LLM_KV_SSM_TIME_STEP_RANK,
     LLM_KV_SSM_DT_B_C_RMS,
 
+    LLM_KV_WKV_HEAD_SIZE,
+
     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_PRE,
     LLM_KV_TOKENIZER_LIST,
@@ -389,11 +396,14 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_USED_COUNT,                 "%s.expert_used_count"                 },
     { LLM_KV_EXPERT_SHARED_COUNT,               "%s.expert_shared_count"               },
     { LLM_KV_EXPERT_WEIGHTS_SCALE,              "%s.expert_weights_scale"              },
-    { LLM_KV_POOLING_TYPE ,                     "%s.pooling_type"                      },
+    { LLM_KV_POOLING_TYPE,                      "%s.pooling_type"                      },
     { LLM_KV_LOGIT_SCALE,                       "%s.logit_scale"                       },
     { LLM_KV_DECODER_START_TOKEN_ID,            "%s.decoder_start_token_id"            },
     { LLM_KV_ATTN_LOGIT_SOFTCAPPING,            "%s.attn_logit_softcapping"            },
     { LLM_KV_FINAL_LOGIT_SOFTCAPPING,           "%s.final_logit_softcapping"           },
+    { LLM_KV_RESCALE_EVERY_N_LAYERS,            "%s.rescale_every_n_layers"            },
+    { LLM_KV_TIME_MIX_EXTRA_DIM,                "%s.time_mix_extra_dim"                },
+    { LLM_KV_TIME_DECAY_EXTRA_DIM,              "%s.time_decay_extra_dim"              },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,             "%s.attention.head_count"             },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,          "%s.attention.head_count_kv"          },
@@ -429,6 +439,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_SSM_TIME_STEP_RANK,            "%s.ssm.time_step_rank" },
     { LLM_KV_SSM_DT_B_C_RMS,                "%s.ssm.dt_b_c_rms" },
 
+    { LLM_KV_WKV_HEAD_SIZE,                 "%s.wkv.head_size" },
+
     { LLM_KV_TOKENIZER_MODEL,                "tokenizer.ggml.model"                    },
     { LLM_KV_TOKENIZER_PRE,                  "tokenizer.ggml.pre"                      },
     { LLM_KV_TOKENIZER_LIST,                 "tokenizer.ggml.tokens"                   },
@@ -518,6 +530,29 @@ enum llm_tensor {
     LLM_TENSOR_SSM_A,
     LLM_TENSOR_SSM_D,
     LLM_TENSOR_SSM_OUT,
+    LLM_TENSOR_TIME_MIX_W1,
+    LLM_TENSOR_TIME_MIX_W2,
+    LLM_TENSOR_TIME_MIX_LERP_X,
+    LLM_TENSOR_TIME_MIX_LERP_W,
+    LLM_TENSOR_TIME_MIX_LERP_K,
+    LLM_TENSOR_TIME_MIX_LERP_V,
+    LLM_TENSOR_TIME_MIX_LERP_R,
+    LLM_TENSOR_TIME_MIX_LERP_G,
+    LLM_TENSOR_TIME_MIX_FIRST,
+    LLM_TENSOR_TIME_MIX_DECAY,
+    LLM_TENSOR_TIME_MIX_DECAY_W1,
+    LLM_TENSOR_TIME_MIX_DECAY_W2,
+    LLM_TENSOR_TIME_MIX_KEY,
+    LLM_TENSOR_TIME_MIX_VALUE,
+    LLM_TENSOR_TIME_MIX_RECEPTANCE,
+    LLM_TENSOR_TIME_MIX_GATE,
+    LLM_TENSOR_TIME_MIX_LN,
+    LLM_TENSOR_TIME_MIX_OUTPUT,
+    LLM_TENSOR_CHANNEL_MIX_LERP_K,
+    LLM_TENSOR_CHANNEL_MIX_LERP_R,
+    LLM_TENSOR_CHANNEL_MIX_KEY,
+    LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,
+    LLM_TENSOR_CHANNEL_MIX_VALUE,
     LLM_TENSOR_ATTN_Q_A,
     LLM_TENSOR_ATTN_Q_B,
     LLM_TENSOR_ATTN_KV_A_MQA,
@@ -1339,6 +1374,40 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_RWKV6,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,                "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM,           "token_embd_norm" },
+            { LLM_TENSOR_OUTPUT_NORM,               "output_norm" },
+            { LLM_TENSOR_OUTPUT,                    "output" },
+            { LLM_TENSOR_ATTN_NORM,                 "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_NORM_2,               "blk.%d.attn_norm_2" },
+            { LLM_TENSOR_TIME_MIX_W1,               "blk.%d.time_mix_w1" },
+            { LLM_TENSOR_TIME_MIX_W2,               "blk.%d.time_mix_w2" },
+            { LLM_TENSOR_TIME_MIX_LERP_X,           "blk.%d.time_mix_lerp_x" },
+            { LLM_TENSOR_TIME_MIX_LERP_W,           "blk.%d.time_mix_lerp_w" },
+            { LLM_TENSOR_TIME_MIX_LERP_K,           "blk.%d.time_mix_lerp_k" },
+            { LLM_TENSOR_TIME_MIX_LERP_V,           "blk.%d.time_mix_lerp_v" },
+            { LLM_TENSOR_TIME_MIX_LERP_R,           "blk.%d.time_mix_lerp_r" },
+            { LLM_TENSOR_TIME_MIX_LERP_G,           "blk.%d.time_mix_lerp_g" },
+            { LLM_TENSOR_TIME_MIX_FIRST,            "blk.%d.time_mix_first" },
+            { LLM_TENSOR_TIME_MIX_DECAY,            "blk.%d.time_mix_decay" },
+            { LLM_TENSOR_TIME_MIX_DECAY_W1,         "blk.%d.time_mix_decay_w1" },
+            { LLM_TENSOR_TIME_MIX_DECAY_W2,         "blk.%d.time_mix_decay_w2" },
+            { LLM_TENSOR_TIME_MIX_KEY,              "blk.%d.time_mix_key" },
+            { LLM_TENSOR_TIME_MIX_VALUE,            "blk.%d.time_mix_value" },
+            { LLM_TENSOR_TIME_MIX_RECEPTANCE,       "blk.%d.time_mix_receptance" },
+            { LLM_TENSOR_TIME_MIX_GATE,             "blk.%d.time_mix_gate" },
+            { LLM_TENSOR_TIME_MIX_LN,               "blk.%d.time_mix_ln" },
+            { LLM_TENSOR_TIME_MIX_OUTPUT,           "blk.%d.time_mix_output" },
+            { LLM_TENSOR_CHANNEL_MIX_LERP_K,        "blk.%d.channel_mix_lerp_k" },
+            { LLM_TENSOR_CHANNEL_MIX_LERP_R,        "blk.%d.channel_mix_lerp_r" },
+            { LLM_TENSOR_CHANNEL_MIX_KEY,           "blk.%d.channel_mix_key" },
+            { LLM_TENSOR_CHANNEL_MIX_VALUE,         "blk.%d.channel_mix_value" },
+            { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,    "blk.%d.channel_mix_receptance" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
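(Note, not part of the patch: per-block entries additionally receive a layer index and a tensor suffix, so the loader below resolves names such as:)

    tn(LLM_TENSOR_TIME_MIX_W1, "weight", 2) -> "blk.2.time_mix_w1.weight"
    tn(LLM_TENSOR_TIME_MIX_LN, "bias",   2) -> "blk.2.time_mix_ln.bias"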
@@ -2151,6 +2220,7 @@ enum e_model {
     MODEL_1B,
     MODEL_1_3B,
     MODEL_1_4B,
+    MODEL_1_6B,
     MODEL_2B,
     MODEL_2_8B,
     MODEL_3B,
@@ -2228,6 +2298,12 @@ struct llama_hparams {
     float f_attn_logit_softcapping = 50.0f;
     float f_final_logit_softcapping = 30.0f;
 
+    // for RWKV
+    uint32_t rescale_every_n_layers = 0;
+    uint32_t time_mix_extra_dim = 0;
+    uint32_t time_decay_extra_dim = 0;
+    uint32_t wkv_head_size = 0;
+
     float    rope_attn_factor = 1.0f;
     float    rope_freq_base_train;
     float    rope_freq_scale_train;
@@ -2291,6 +2367,11 @@ struct llama_hparams {
         if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
         if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
 
+        if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true;
+        if (this->time_mix_extra_dim     != other.time_mix_extra_dim)     return true;
+        if (this->time_decay_extra_dim   != other.time_decay_extra_dim)   return true;
+        if (this->wkv_head_size          != other.wkv_head_size)          return true;
+
         if (this->dec_start_token_id != other.dec_start_token_id) return true;
 
         const float EPSILON = 1e-9f;
@@ -2354,15 +2435,25 @@ struct llama_hparams {
     }
 
     uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
-        // corresponds to Mamba's conv_states size
-        // TODO: maybe support other convolution strides than 1
-        // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+        // corresponds to Mamba's conv_states size or RWKV's token_shift states size
+        if (wkv_head_size != 0) {
+            // for RWKV models
+            return 2 * n_embd;
+        } else {
+            // TODO: maybe support other convolution strides than 1
+            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+            return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+        }
     }
 
     uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
-        // corresponds to Mamba's ssm_states size
-        return ssm_d_state * ssm_d_inner;
+        if (wkv_head_size != 0) {
+            // corresponds to RWKV's wkv_states size
+            return n_embd * wkv_head_size;
+        } else {
+            // corresponds to Mamba's ssm_states size
+            return ssm_d_state * ssm_d_inner;
+        }
     }
 };
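(A worked example of these two sizes; the dimensions are assumptions for a hypothetical RWKV v6 1.6B-class config, not values read from a model file:)

    #include <cstdint>

    constexpr uint32_t n_embd        = 2048; // assumed
    constexpr uint32_t wkv_head_size = 64;   // assumed

    // token-shift state: one attn vector + one ffn vector per sequence
    constexpr uint32_t n_embd_k_s = 2 * n_embd;             // = 4096
    // wkv state: (head_size x head_size) per head, n_embd / head_size heads
    constexpr uint32_t n_embd_v_s = n_embd * wkv_head_size; // = 131072

    static_assert(n_embd_k_s == 4096 && n_embd_v_s == 131072, "per-layer state sizes");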
 
@@ -2501,6 +2592,36 @@ struct llama_layer {
     struct ggml_tensor * ssm_conv1d_b;
     struct ggml_tensor * ssm_dt_b;
 
+    // rwkv
+    struct ggml_tensor * time_mix_w1;
+    struct ggml_tensor * time_mix_w2;
+    struct ggml_tensor * time_mix_lerp_x;
+    struct ggml_tensor * time_mix_lerp_w;
+    struct ggml_tensor * time_mix_lerp_k;
+    struct ggml_tensor * time_mix_lerp_v;
+    struct ggml_tensor * time_mix_lerp_r;
+    struct ggml_tensor * time_mix_lerp_g;
+
+    struct ggml_tensor * time_mix_first;
+    struct ggml_tensor * time_mix_decay;
+    struct ggml_tensor * time_mix_decay_w1;
+    struct ggml_tensor * time_mix_decay_w2;
+    struct ggml_tensor * time_mix_key;
+    struct ggml_tensor * time_mix_value;
+    struct ggml_tensor * time_mix_receptance;
+    struct ggml_tensor * time_mix_gate;
+
+    struct ggml_tensor * time_mix_ln;
+    struct ggml_tensor * time_mix_ln_b;
+    struct ggml_tensor * time_mix_output;
+
+    struct ggml_tensor * channel_mix_lerp_k;
+    struct ggml_tensor * channel_mix_lerp_r;
+
+    struct ggml_tensor * channel_mix_key;
+    struct ggml_tensor * channel_mix_receptance;
+    struct ggml_tensor * channel_mix_value;
+
     // long rope factors
     struct ggml_tensor * rope_long  = nullptr;
     struct ggml_tensor * rope_short = nullptr;
@@ -3426,7 +3547,7 @@ static bool llama_kv_cache_find_slot(
     const uint32_t n_seq_tokens = batch.n_seq_tokens;
 
     if (cache.recurrent) {
-        // For recurrent state architectures (like Mamba),
+        // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should always be contiguous.
 
@@ -3675,7 +3796,7 @@ static bool llama_kv_cache_seq_rm(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
-    // models like Mamba can't have a state partially erased
+    // models like Mamba or RWKV can't have a state partially erased
     if (cache.recurrent) {
         if (seq_id >= (int64_t) cache.size) {
             // could be fatal
@@ -3811,7 +3932,7 @@ static void llama_kv_cache_seq_add(
     if (p0 == p1) return;
 
     if (cache.recurrent) {
-        // for Mamba-like models, only the pos needs to be shifted
+        // for Mamba-like or RWKV models, only the pos needs to be shifted
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
             const int32_t tail_id = cache.cells[seq_id].tail;
             if (tail_id >= 0) {
@@ -3860,7 +3981,7 @@ static void llama_kv_cache_seq_div(
     if (p0 == p1) return;
 
     if (cache.recurrent) {
-        // for Mamba-like models, only the pos needs to be changed
+        // for Mamba-like or RWKV models, only the pos needs to be changed
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
             const int32_t tail_id = cache.cells[seq_id].tail;
             if (tail_id >= 0) {
@@ -5051,6 +5172,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_1B:            return "1B";
         case MODEL_1_3B:          return "1.3B";
         case MODEL_1_4B:          return "1.4B";
+        case MODEL_1_6B:          return "1.6B";
         case MODEL_2B:            return "2B";
         case MODEL_2_8B:          return "2.8B";
         case MODEL_3B:            return "3B";
@@ -5097,6 +5219,7 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
         case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
         case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
         case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
+        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
         default:                    return "unknown";
     }
 }
@@ -5793,6 +5916,26 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_RWKV6:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
+                ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
+                ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
+                ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false);
+
+                switch (hparams.n_layer) {
+                    case 24: model.type = e_model::MODEL_1_6B; break;
+                    case 32:
+                        switch (hparams.n_embd) {
+                            case 2560: model.type = e_model::MODEL_3B; break;
+                            case 4096: model.type = e_model::MODEL_7B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 61: model.type = e_model::MODEL_14B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
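(For reference: with the "%s" prefix expanded, these reads correspond to GGUF keys along the lines of the list below, assuming the architecture string is "rwkv6"; the trailing ``false`` marks ``rescale_every_n_layers`` as optional.)

    rwkv6.attention.layer_norm_epsilon
    rwkv6.wkv.head_size
    rwkv6.time_mix_extra_dim
    rwkv6.time_decay_extra_dim
    rwkv6.rescale_every_n_layers   (optional)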
 
@@ -5922,6 +6065,15 @@ static void llm_load_vocab(
                 }
 #endif
             }
+        } else if (tokenizer_model == "rwkv") {
+            vocab.type = LLAMA_VOCAB_TYPE_RWKV;
+
+            // default special tokens
+            vocab.special_bos_id = -1;
+            vocab.special_eos_id = -1;
+            vocab.special_unk_id = -1;
+            vocab.special_sep_id = -1;
+            vocab.special_pad_id = -1;
         } else {
             throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
         }
@@ -6053,6 +6205,12 @@ static void llm_load_vocab(
             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
             vocab.tokenizer_add_bos = false;
             vocab.tokenizer_add_eos = true;
+        } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) {
+            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            vocab.tokenizer_add_space_prefix = false;
+            vocab.tokenizer_clean_spaces = false;
+            vocab.tokenizer_add_bos = false;
+            vocab.tokenizer_add_eos = false;
         } else {
             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
         }
@@ -6157,6 +6315,10 @@ static void llm_load_vocab(
         }
     } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
         vocab.linefeed_id = vocab.special_pad_id;
+    } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) {
+        const std::vector<int> ids = llama_tokenize_internal(vocab, "\n", false);
+        GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
+        vocab.linefeed_id = ids[0];
     } else {
         const std::vector<int> ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A
         GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
@@ -8203,6 +8365,68 @@ static bool llm_load_tensors(
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
                     }
                 } break;
+            case LLM_ARCH_RWKV6:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // Block 0, LN0
+                    model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
+                    model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd});
+
+                    // output
+                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                    model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
+                    model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+
+                    const int time_mix_extra_dim = hparams.time_mix_extra_dim;
+                    const int time_decay_extra_dim = hparams.time_decay_extra_dim;
+                    const int head_size = hparams.wkv_head_size;
+                    const int attn_hidden_size = n_embd;
+                    const int ffn_size = hparams.n_ff_arr[0];
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+
+                        layer.attn_norm_2   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd});
+                        layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd});
+
+                        layer.time_mix_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5});
+                        layer.time_mix_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5});
+
+                        layer.time_mix_lerp_x = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1});
+                        layer.time_mix_lerp_w = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1});
+                        layer.time_mix_lerp_k = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1});
+                        layer.time_mix_lerp_v = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1});
+                        layer.time_mix_lerp_r = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1});
+                        layer.time_mix_lerp_g = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1});
+
+                        layer.time_mix_first = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size});
+                        layer.time_mix_decay = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd});
+                        layer.time_mix_decay_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim});
+                        layer.time_mix_decay_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size});
+                        layer.time_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd});
+                        layer.time_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd});
+                        layer.time_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd});
+                        layer.time_mix_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd});
+
+                        layer.time_mix_ln = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd});
+                        layer.time_mix_ln_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd});
+                        layer.time_mix_output = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size});
+
+                        layer.channel_mix_lerp_k = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1});
+                        layer.channel_mix_lerp_r = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1});
+
+                        layer.channel_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size});
+                        layer.channel_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd});
+                        layer.channel_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd});
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
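(For orientation, the RWKV6 tensor shapes under one assumed 1.6B-class configuration: n_embd = 2048, head_size = 64, time_mix_extra_dim = 32, time_decay_extra_dim = 64; all four numbers are assumptions, not read from a model:)

    time_mix_w1        {2048, 160}      // time_mix_extra_dim * 5 fused low-rank columns
    time_mix_w2        {32, 2048, 5}    // one {32, 2048} projection per lerp target
    time_mix_first     {64, 32}         // head_size x head_count (the "bonus" u)
    time_mix_decay_w1  {2048, 64}
    time_mix_decay_w2  {64, 2048}
    channel_mix_key    {2048, ffn_size} // ffn_size taken from ``feed_forward_length``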
@@ -9162,6 +9386,171 @@ static struct ggml_tensor * llm_build_mamba(
     return cur;
 }
 
+static struct ggml_tensor * llm_build_rwkv6_time_mix(
+        struct llama_context & lctx,
+        struct ggml_context * ctx,
+        const struct llama_layer * layer,
+        struct ggml_tensor * cur,
+        struct ggml_tensor * x_prev,
+        struct ggml_tensor ** wkv_state) {
+    size_t n_embed      = cur->ne[0];
+    size_t n_seq_tokens = cur->ne[1];
+    size_t n_seqs       = cur->ne[2];
+
+    size_t head_size  = layer->time_mix_first->ne[0];
+    size_t head_count = layer->time_mix_first->ne[1];
+
+    size_t n_tokens = n_seqs * n_seq_tokens;
+
+    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
+
+    sx  = ggml_reshape_2d(ctx, sx,  n_embed, n_tokens);
+    cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
+
+    struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);
+
+    xxx = ggml_reshape_4d(
+        ctx,
+        ggml_tanh(
+            ctx,
+            ggml_mul_mat(ctx, layer->time_mix_w1, xxx)
+        ),
+        layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
+    );
+
+    xxx = ggml_cont(ctx, ggml_permute(ctx, xxx, 0, 1, 3, 2));
+
+    xxx = ggml_mul_mat(
+        ctx,
+        ggml_reshape_4d(
+            ctx,
+            layer->time_mix_w2,
+            layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
+        ),
+        xxx
+    );
+
+    struct ggml_tensor * mw = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], 0);
+    struct ggml_tensor * mk = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * sizeof(float));
+    struct ggml_tensor * mv = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 2 * sizeof(float));
+    struct ggml_tensor * mr = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 3 * sizeof(float));
+    struct ggml_tensor * mg = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 4 * sizeof(float));
+
+    struct ggml_tensor * xw = ggml_add(
+        ctx,
+        ggml_mul(
+            ctx,
+            ggml_add(ctx, mw, layer->time_mix_lerp_w),
+            sx
+        ),
+        cur
+    );
+
+    struct ggml_tensor * xk = ggml_add(
+        ctx,
+        ggml_mul(
+            ctx,
+            ggml_add(ctx, mk, layer->time_mix_lerp_k),
+            sx
+        ),
+        cur
+    );
+
+    struct ggml_tensor * xv = ggml_add(
+        ctx,
+        ggml_mul(
+            ctx,
+            ggml_add(ctx, mv, layer->time_mix_lerp_v),
+            sx
+        ),
+        cur
+    );
+
+    struct ggml_tensor * xr = ggml_add(
+        ctx,
+        ggml_mul(
+            ctx,
+            ggml_add(ctx, mr, layer->time_mix_lerp_r),
+            sx
+        ),
+        cur
+    );
+
+    struct ggml_tensor * xg = ggml_add(
+        ctx,
+        ggml_mul(
+            ctx,
+            ggml_add(ctx, mg, layer->time_mix_lerp_g),
+            sx
+        ),
+        cur
+    );
+
+    struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1,         head_count, n_tokens);
+    struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk), 1,         head_size, head_count, n_tokens);
+    struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv), head_size, 1,         head_count, n_tokens);
+    struct ggml_tensor * g = ggml_silu(
+        ctx,
+        llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
+    );
+
+    struct ggml_tensor * w = ggml_mul_mat(
+        ctx,
+        layer->time_mix_decay_w2,
+        ggml_tanh(
+            ctx,
+            ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw)
+        )
+    );
+
+    w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed));
+    w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
+    w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
+
+    k = ggml_transpose(ctx, k);
+    v = ggml_transpose(ctx, v);
+    r = ggml_transpose(ctx, r);
+
+    struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+    cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
+    *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_seqs, n_embed * n_tokens * sizeof(float));
+
+    // group norm with head_count groups: after the reshape below,
+    // ggml_norm normalizes each head's channels independently
+    cur = ggml_reshape_3d(ctx, cur, n_embed / head_count, head_count, n_tokens);
+    cur = ggml_norm(ctx, cur, 64e-5f);
+
+    // Convert back to regular vectors.
+    cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
+    cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+
+    cur = ggml_mul(ctx, cur, g);
+    cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
+
+    return ggml_reshape_3d(ctx, cur, n_embed, n_seq_tokens, n_seqs);
+}
+
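(In equation form, a sketch of what the graph above computes; the notation is mine, not the patch's. With lerp(a, b, m) = a + (b - a) * m applied element-wise, x the current token and x_prev the shifted one:)

    (m_w, m_k, m_v, m_r, m_g) = \mathrm{split}\left( W_2 \tanh\left( W_1 \, \mathrm{lerp}(x, x_{prev}, \mu_x) \right) \right)
    x_s = \mathrm{lerp}(x, x_{prev}, \mu_s + m_s), \quad s \in \{w, k, v, r, g\}
    w_t = \exp\left( -\exp\left( d + W_{d2} \tanh( W_{d1} x_w ) \right) \right)

(and per head, with a state S of size head_size x head_size and u = time_mix_first:)

    S_t = \mathrm{diag}(w_t) \, S_{t-1} + k_t^\top v_t
    o_t = r_t \left( \mathrm{diag}(u) \, k_t^\top v_t + S_{t-1} \right)

(followed by the per-head normalization, the SiLU(W_g x_g) gate, and the output projection.)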
+static struct ggml_tensor * llm_build_rwkv6_channel_mix(
+        struct llama_context & lctx,
+        struct ggml_context * ctx,
+        const struct llama_layer * layer,
+        struct ggml_tensor * cur,
+        struct ggml_tensor * x_prev) {
+    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
+    struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur);
+    struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur);
+
+    struct ggml_tensor * r = ggml_sigmoid(ctx, llm_build_lora_mm(lctx, ctx, layer->channel_mix_receptance, xr));
+    struct ggml_tensor * k = ggml_sqr(
+        ctx,
+        ggml_relu(
+            ctx,
+            llm_build_lora_mm(lctx, ctx, layer->channel_mix_key, xk)
+        )
+    );
+
+    return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
+}
+
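(Equivalently, again my sketch rather than the patch's notation:)

    x_k = \mathrm{lerp}(x, x_{prev}, \mu_k), \qquad x_r = \mathrm{lerp}(x, x_{prev}, \mu_r)
    o = \sigma(W_r x_r) \odot \left( W_v \, \max(W_k x_k, 0)^2 \right)

(i.e. a squared-ReLU feed-forward gated by a sigmoid receptance.)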
 struct llm_build_context {
     const llama_model    & model;
           llama_context  & lctx;
@@ -14683,6 +15072,117 @@ struct llm_build_context {
 
         return gf;
     }
+
+    ggml_cgraph * build_rwkv6() {
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        // Token shift state dimensions should be 2 * n_embd
+        GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
+
+        const int64_t n_seqs = batch.n_seqs;
+        const int64_t n_seq_tokens = batch.n_seq_tokens;
+        const int64_t n_tokens = batch.n_tokens;
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+        struct ggml_tensor * state_copy = build_inp_s_copy();
+        struct ggml_tensor * state_mask = build_inp_s_mask();
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
+
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];
+
+            // (ab)using the KV cache to store the states
+            struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.k_l[il], state_copy, state_mask,
+                    hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
+            struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.v_l[il], state_copy, state_mask,
+                    hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
+
+            cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+            token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs);
+
+            struct ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
+            struct ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
+
+            struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, il);
+            struct ggml_tensor * x_prev = ggml_concat(
+                ctx0,
+                att_shift,
+                ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
+                1
+            );
+
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
+            ggml_build_forward_expand(gf, cur);
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    wkv_states,
+                    ggml_view_1d(
+                        ctx0,
+                        kv_self.v_l[il],
+                        hparams.n_embd_v_s() * n_seqs,
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+                    )
+                )
+            );
+
+            struct ggml_tensor * x_norm_ffn = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, il);
+            x_prev = ggml_concat(
+                ctx0,
+                ffn_shift,
+                ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
+                1
+            );
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev));
+            ggml_build_forward_expand(gf, cur);
+
+            struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
+            struct ggml_tensor * last_norm_ffn = ggml_view_3d(ctx0, x_norm_ffn, n_embd, 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_ffn));
+
+            token_shift = ggml_concat(ctx0, last_norm_att, last_norm_ffn, 1);
+
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * 2, 0),
+                    ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
+                )
+            );
+
+            if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
+                cur = ggml_scale(ctx0, cur, 0.5f);
+            }
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        cb(cur, "result_output", -1);
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
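(A summary of the state bookkeeping in build_rwkv6, in my words: RWKV reuses the KV cache as a per-sequence state store.)

    k_l[il] : token_shift = [ last x_norm_att | last x_norm_ffn ] -> 2 * n_embd floats per sequence
    v_l[il] : wkv_states  = stacked per-head state matrices       -> n_embd * head_size floats per sequence

(Both are written back at offset kv_head on every graph, so parallel sequences keep independent states.)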
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -14929,6 +15429,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_exaone();
             } break;
+        case LLM_ARCH_RWKV6:
+            {
+                result = llm.build_rwkv6();
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -16973,6 +17477,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         // NOTE: can't use LLM_TN here because the layer number is not known
         quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
 
+        // do not quantize RWKV's small time_mix tensors (time_mix_first, time_mix_w1/w2)
+        quantize &= name.find("time_mix_first.weight") == std::string::npos;
+        quantize &= name.find("time_mix_w1.weight") == std::string::npos;
+        quantize &= name.find("time_mix_w2.weight") == std::string::npos;
+
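(A sketch of the substring test above, applied to a hypothetical tensor name; not part of the patch:)

    #include <string>

    static bool would_quantize(const std::string & name) {
        return name.find("time_mix_w1.weight") == std::string::npos;
    }

    // would_quantize("blk.0.time_mix_w1.weight") == false, so the tensor stays F32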
         // do not quantize relative position bias (T5)
         quantize &= name.find("attn_rel_b.weight") == std::string::npos;
 
@@ -17977,6 +18486,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_T5:
         case LLM_ARCH_T5ENCODER:
         case LLM_ARCH_JAIS:
+        case LLM_ARCH_RWKV6:
             return LLAMA_ROPE_TYPE_NONE;
 
         // use what we call a normal RoPE, operating on pairs of consecutive head values
@@ -18145,6 +18655,7 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
 bool llama_model_is_recurrent(const struct llama_model * model) {
     switch (model->arch) {
         case LLM_ARCH_MAMBA:  return true;
+        case LLM_ARCH_RWKV6:  return true;
         default:              return false;
     }
 }
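(A hypothetical caller, to show what the flag implies; not part of the patch. Recurrent models keep one whole-sequence state per cache cell, which is why the kv-cache changes earlier in this patch reject partial state erasure.)

    #include "llama.h"

    static bool needs_state_cache(const struct llama_model * model) {
        // true for Mamba and, with this patch, RWKV v6
        return llama_model_is_recurrent(model);
    }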