llama: Add support for RWKV v7 architecture (#12412)
author    Molly Sophia <redacted>
Mon, 17 Mar 2025 23:27:50 +0000 (07:27 +0800)
committer GitHub <redacted>
Mon, 17 Mar 2025 23:27:50 +0000 (07:27 +0800)
* ggml: Add op l2_norm

Signed-off-by: Molly Sophia <redacted>
* ggml: Add op rwkv_wkv7

Signed-off-by: Molly Sophia <redacted>
* llama: Add support for RWKV7 and ARWKV7 models

Signed-off-by: Molly Sophia <redacted>
* llama: fix inference with RWKV6Qwen2

Signed-off-by: Molly Sophia <redacted>
* llama: add more (a)rwkv7 variants in size

Signed-off-by: Molly Sophia <redacted>
* Apply code-format changes

Signed-off-by: Molly Sophia <redacted>
* fix MUSA build

Signed-off-by: Molly Sophia <redacted>
* llama: fix shape error with rwkv using llama-parallel

Signed-off-by: Molly Sophia <redacted>
---------

Signed-off-by: Molly Sophia <redacted>
36 files changed:
convert_hf_to_gguf.py
ggml/include/ggml.h
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/norm.cu
ggml/src/ggml-cuda/norm.cuh
ggml/src/ggml-cuda/wkv.cu [new file with mode: 0644]
ggml/src/ggml-cuda/wkv.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/wkv6.cu [deleted file]
ggml/src/ggml-cuda/wkv6.cuh [deleted file]
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal
ggml/src/ggml-sycl/backend.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-sycl/norm.cpp
ggml/src/ggml-sycl/norm.hpp
ggml/src/ggml-sycl/wkv.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/wkv.hpp [new file with mode: 0644]
ggml/src/ggml-sycl/wkv6.cpp [deleted file]
ggml/src/ggml-sycl/wkv6.hpp [deleted file]
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp [new file with mode: 0644]
ggml/src/ggml.c
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-hparams.h
src/llama-model.cpp
src/llama-model.h
src/llama-quant.cpp
tests/test-backend-ops.cpp

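Note: as a reading aid for the diffs below, the two new ops compute the following in f32 (paraphrased from the CPU implementation in ggml-cpu.c further down; the notation is mine, not the commit's):

    l2_norm, applied along rows:
        y = x / max(||x||_2, eps)

    rwkv_wkv7, per head and per token t, with a head_size x head_size state S:
        sa_i    = sum_j a_j * S_prev[i][j]
        S[i][j] = S_prev[i][j] * w_j + v_i * k_j + sa_i * b_j
        y_i     = sum_j S[i][j] * r_j
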
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index b5d95bd5639f3a294a0184de69a05d02f48d69ef..d13d57c54154a55df3b8a50c1595f4e551a9fb7f 100755 (executable)
@@ -908,6 +908,40 @@ class Model:
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
+    def _set_vocab_rwkv_world(self):
+        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
+        vocab_size = self.hparams.get("vocab_size", 65536)
+
+        tokens: list[bytes] = ['<s>'.encode("utf-8")]
+        toktypes: list[int] = [gguf.TokenType.CONTROL]
+
+        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
+            lines = f.readlines()
+            for line in lines:
+                parts = line.split(' ')
+                assert len(parts) >= 3
+                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
+                token = token.encode("utf-8") if isinstance(token, str) else token
+                assert isinstance(token, bytes)
+                assert len(token) == token_len
+                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
+                tokens.append(token_text.encode("utf-8"))
+                toktypes.append(gguf.TokenType.NORMAL)
+        remainder = vocab_size - len(tokens)
+        assert remainder >= 0
+        for i in range(len(tokens), vocab_size):
+            tokens.append(f"[PAD{i}]".encode("utf-8"))
+            toktypes.append(gguf.TokenType.UNUSED)
+
+        self.gguf_writer.add_tokenizer_model("rwkv")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+        special_vocab.chat_template = "rwkv-world"
+        # hack: Add '\n\n' as the EOT token to make it chat normally
+        special_vocab._set_special_token("eot", 261)
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
         tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
         logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
@@ -3412,38 +3446,7 @@ class Rwkv6Model(Model):
     model_arch = gguf.MODEL_ARCH.RWKV6
 
     def set_vocab(self):
-        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
-        vocab_size = self.hparams.get("vocab_size", 65536)
-
-        tokens: list[bytes] = ['<s>'.encode("utf-8")]
-        toktypes: list[int] = [gguf.TokenType.CONTROL]
-
-        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
-            lines = f.readlines()
-            for line in lines:
-                parts = line.split(' ')
-                assert len(parts) >= 3
-                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
-                token = token.encode("utf-8") if isinstance(token, str) else token
-                assert isinstance(token, bytes)
-                assert len(token) == token_len
-                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
-                tokens.append(token_text.encode("utf-8"))
-                toktypes.append(gguf.TokenType.NORMAL)
-        remainder = vocab_size - len(tokens)
-        assert remainder >= 0
-        for i in range(len(tokens), vocab_size):
-            tokens.append(f"[PAD{i}]".encode("utf-8"))
-            toktypes.append(gguf.TokenType.UNUSED)
-
-        self.gguf_writer.add_tokenizer_model("rwkv")
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
-        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
-        special_vocab.chat_template = "rwkv-world"
-        # hack: Add '\n\n' as the EOT token to make it chat normally
-        special_vocab._set_special_token("eot", 261)
-        special_vocab.add_to_gguf(self.gguf_writer)
+        self._set_vocab_rwkv_world()
 
     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
@@ -3565,6 +3568,168 @@ class RWKV6Qwen2Model(Rwkv6Model):
             yield (new_name, data)
 
 
+@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
+class Rwkv7Model(Model):
+    model_arch = gguf.MODEL_ARCH.RWKV7
+
+    def set_vocab(self):
+        self._set_vocab_rwkv_world()
+
+    def calc_lora_rank(self, hidden_size, exponent, multiplier):
+        return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        try:
+            head_size = self.hparams["head_size"]
+            layer_norm_eps = self.hparams["layer_norm_epsilon"]
+        except KeyError:
+            head_size = self.hparams["head_dim"]
+            layer_norm_eps = self.hparams["norm_eps"]
+        hidden_size = self.hparams["hidden_size"]
+        intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)
+
+        # ICLR: In-Context-Learning-Rate
+        try:
+            lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+        except KeyError:
+            lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+    lerp_weights: dict[int, dict[str, Tensor]] = {}
+    lora_needs_transpose: bool = True
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # unify tensor names here to make life easier
+        name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
+        name = name.replace("self_attn", "attention").replace("attn", "attention")
+        name = name.replace("time_mixer.", "")
+        # lora layer names in fla-hub's impl
+        if "_lora.lora" in name:
+            self.lora_needs_transpose = False
+        name = name.replace("_lora.lora.0.weight", "1.weight")
+        name = name.replace("_lora.lora.2.weight", "2.weight")
+        name = name.replace("_lora.lora.2.bias", "0.weight")
+
+        name = name.replace("feed_forward_norm", "ln2")
+        name = name.replace("g_norm", "ln_x")
+
+        if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
+            # some models have dummy v0/v1/v2 on first layer while others don't
+            # ignore them all since they are not used
+            return
+
+        wkv_has_gate = self.hparams.get("wkv_has_gate", True)
+        lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]
+
+        if bid is not None and "attention.x_" in name:
+            if "attention.x_x" in name:
+                # already concatenated
+                new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                data = data_torch.reshape(len(lerp_list), 1, 1, -1)
+                yield (new_name, data)
+            else:
+                try:
+                    self.lerp_weights[bid][name] = data_torch
+                except KeyError:
+                    self.lerp_weights[bid] = {name: data_torch}
+                if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list):
+                    new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                    data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0)
+                    yield (new_name, data)
+            return
+        else:
+            data_torch = data_torch.squeeze()
+            new_name = self.map_tensor_name(name)
+
+            if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
+                new_name += ".weight"
+
+            if self.lora_needs_transpose and any(
+                new_name.endswith(t) for t in [
+                    "time_mix_w1.weight", "time_mix_w2.weight",
+                    "time_mix_a1.weight", "time_mix_a2.weight",
+                    "time_mix_v1.weight", "time_mix_v2.weight",
+                    "time_mix_g1.weight", "time_mix_g2.weight",
+                ]
+            ):
+                data_torch = data_torch.transpose(0, 1)
+
+            if 'r_k' in new_name:
+                data_torch = data_torch.flatten()
+
+            if bid == 0 and "time_mix_a" in new_name:
+                # dummy v0/v1/v2 on first layer
+                # easiest way to make llama happy
+                yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)
+
+            yield (new_name, data_torch)
+
+
+@Model.register("RwkvHybridForCausalLM")
+class ARwkv7Model(Rwkv7Model):
+    model_arch = gguf.MODEL_ARCH.ARWKV7
+
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        hidden_size = self.hparams["hidden_size"]
+        head_size = self.hparams["head_size"]
+        rms_norm_eps = self.hparams["rms_norm_eps"]
+        intermediate_size = self.hparams["intermediate_size"]
+        wkv_has_gate = self.hparams["wkv_has_gate"]
+        assert self.hparams["wkv_version"] == 7
+
+        # ICLR: In-Context-Learning-Rate
+        lora_rank_decay = 64
+        lora_rank_iclr = 64
+        lora_rank_value_residual_mix = 32
+        lora_rank_gate = 128 if wkv_has_gate else 0
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_token_shift_count(1)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+
 @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA
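As a quick sanity check of calc_lora_rank above, take hidden_size = 768 (an illustrative value, not one from the diff): round(768**0.5 * 1.8 / 32) * 32 = round(1.56) * 32 = 64 for decay and iclr, round(768**0.5 * 1.3 / 32) * 32 = round(1.13) * 32 = 32 for the value-residual mix, and round(768**0.8 * 0.6 / 32) * 32 = round(3.81) * 32 = 128 for the gate. These happen to coincide with the fixed ranks (64, 64, 32, 128) hardcoded in ARwkv7Model.
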
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 2e5076d36a09ffb81fb4786dd885742ec770c312..cb3edb10d47026be000dee3382edb3948a5b18e2 100644 (file)
@@ -454,6 +454,7 @@ extern "C" {
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
+        GGML_OP_L2_NORM,
 
         GGML_OP_MUL_MAT,
         GGML_OP_MUL_MAT_ID,
@@ -502,6 +503,7 @@ extern "C" {
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
         GGML_OP_GATED_LINEAR_ATTN,
+        GGML_OP_RWKV_WKV7,
 
         GGML_OP_UNARY,
 
@@ -1095,6 +1097,18 @@ extern "C" {
             int                   n_groups,
             float                 eps);
 
+    // l2 normalize along rows
+    // used in rwkv v7
+    GGML_API struct ggml_tensor * ggml_l2_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
+    GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
     // a - x
     // b - dy
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
@@ -1890,6 +1904,16 @@ extern "C" {
             struct ggml_tensor  * state,
             float scale);
 
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * r,
+            struct ggml_tensor  * w,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * state);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
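A minimal sketch of how the two new entry points might be wired into a graph (shapes inferred from the CPU implementation later in this diff; ctx and all dimension values are illustrative assumptions, and which tensor the model graph actually normalizes is decided in src/llama-model.cpp, not reproduced here):

    // assumed context: ctx is an initialized ggml_context; dims are illustrative
    const int64_t head_size = 64, n_heads = 8, n_tokens = 4, n_seqs = 1;
    struct ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_heads, n_tokens);
    struct ggml_tensor * w = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_heads, n_tokens);
    struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_heads, n_tokens);
    struct ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_heads, n_tokens);
    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_heads, n_tokens);
    struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_heads, n_tokens);
    // one head_size x head_size state matrix per head, per sequence
    struct ggml_tensor * s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, head_size * head_size * n_heads, n_seqs);

    // row-wise L2 normalization; the eps value here is an assumed small constant
    struct ggml_tensor * a_norm = ggml_l2_norm(ctx, a, 1e-12f);
    // fused WKV7 recurrence; dst packs n_tokens * C outputs followed by the updated states
    struct ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a_norm, b, s);
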
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index f2ab4c5d695820c80f10c7860a7547a5b1e2ea20..75dc96b4786551cffc5da27fa6a0494281f89a2e 100644 (file)
@@ -8548,6 +8548,69 @@ static void ggml_compute_forward_group_norm(
     }
 }
 
+// ggml_compute_forward_l2_norm
+
+static void ggml_compute_forward_l2_norm_f32(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps >= 0.0f);
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)(x[i00] * x[i00]);
+                }
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                memcpy(y, x, ne00 * sizeof(float));
+
+                const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_l2_norm(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_l2_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_mul_mat
 
 static void ggml_compute_forward_mul_mat_one_chunk(
@@ -13604,6 +13667,184 @@ static void ggml_compute_forward_gla(
     }
 }
 
+// ggml_compute_forward_rwkv_wkv7
+
+static void ggml_compute_forward_rwkv_wkv7_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[6]->ne[1];
+    const int64_t head_size = C / HEADS;
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * r = (float *) dst->src[0]->data;
+    float * w = (float *) dst->src[1]->data;
+    float * k = (float *) dst->src[2]->data;
+    float * v = (float *) dst->src[3]->data;
+    float * a = (float *) dst->src[4]->data;
+    float * b = (float *) dst->src[5]->data;
+
+    int64_t t_stride = HEADS * head_size; // same as C
+
+    int64_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    int64_t h_stride_2d = head_size * head_size;
+
+    #if defined(GGML_SIMD)
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t ii = 0; ii < head_size; ii++) {
+                    int64_t t_h_i_offset = t_h_offset + ii;
+                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                    float sa = 0;
+                    {
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                            }
+                        }
+                        GGML_F32_VEC_REDUCE(sa, sum);
+                    }
+
+                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+
+                    int64_t j = 0;
+                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    for (; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+
+                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+
+                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+
+                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                            // kv + s * decay + sa * b
+                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+
+                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                        }
+                    }
+                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                    // There shouldn't be any leftovers, since head_size is a multiple of GGML_F32_STEP.
+                    for (; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v[t_h_i_offset] * k_val;
+
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                    }
+                }
+            }
+        }
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    int64_t t_h_i_offset = t_h_offset + i;
+                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float v_val = v[t_h_i_offset];
+
+                    float sa = 0, result = 0;
+                    for (int64_t j = 0; j < head_size; j++) {
+                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                    }
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        result += state_cur[h_2d_i_j_offset] * r_val;
+                    }
+                    dst_data[t_h_i_offset] = result;
+                }
+            }
+        }
+    #endif
+}
+
+
+static void ggml_compute_forward_rwkv_wkv7(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -14170,6 +14411,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_group_norm(params, tensor);
             } break;
+        case GGML_OP_L2_NORM:
+            {
+                ggml_compute_forward_l2_norm(params, tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor);
@@ -14357,6 +14602,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gla(params, tensor);
             } break;
+        case GGML_OP_RWKV_WKV7:
+            {
+                ggml_compute_forward_rwkv_wkv7(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -14582,6 +14831,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_RMS_NORM_BACK:
+        case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_CONCAT:
         case GGML_OP_MUL_MAT:
@@ -14648,14 +14898,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_RWKV_WKV7:
             {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_RWKV_WKV6:
-        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
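The hunk above moves GGML_OP_RWKV_WKV6 and GGML_OP_GATED_LINEAR_ATTN out of the single-task bucket and into the n_tasks = n_threads bucket alongside the new GGML_OP_RWKV_WKV7; the kernels split work across heads via the h_start/h_end computation shown earlier. A standalone sketch of that partitioning (plain C, values are illustrative):

    #include <stdio.h>

    int main(void) {
        const int HEADS = 6, nth = 4; // illustrative values
        for (int ith = 0; ith < nth; ith++) {
            const int h_start = (HEADS * ith) / nth;
            const int h_end   = (HEADS * (ith + 1)) / nth < HEADS ?
                        (HEADS * (ith + 1)) / nth : HEADS;
            printf("thread %d: heads [%d, %d)\n", ith, h_start, h_end);
        }
        // prints [0,1) [1,3) [3,4) [4,6): every head is covered exactly once
        return 0;
    }
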
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 9bba398ce6be113c1c5a4b438e8393253cddefa9..8fb063822cfb74d5b4492615b99c3100c9d240b8 100644 (file)
@@ -36,7 +36,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
 #include "ggml.h"
 
@@ -2196,6 +2196,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GROUP_NORM:
             ggml_cuda_op_group_norm(ctx, dst);
             break;
+        case GGML_OP_L2_NORM:
+            ggml_cuda_op_l2_norm(ctx, dst);
+            break;
         case GGML_OP_CONCAT:
             ggml_cuda_op_concat(ctx, dst);
             break;
@@ -2304,6 +2307,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_cuda_op_gated_linear_attn(ctx, dst);
             break;
+        case GGML_OP_RWKV_WKV7:
+            ggml_cuda_op_rwkv_wkv7(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
@@ -3161,6 +3167,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             break;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
             return true;
         case GGML_OP_RMS_NORM_BACK:
             return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
@@ -3215,6 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_RWKV_WKV7:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index f127616eddade711fd7c200c58ac4e479d9ee89c..0020dbcec5fb59dae4d157b1771935ac65b5c868 100644 (file)
@@ -201,6 +201,85 @@ static __global__ void rms_norm_back_f32(
     }
 }
 
+// template <int block_size>
+// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
+//     const int row = blockIdx.x*blockDim.y + threadIdx.y;
+//     const int tid = threadIdx.x;
+
+//     float tmp = 0.0f; // partial sum for thread in warp
+
+//     for (int col = tid; col < ncols; col += block_size) {
+//         const float xi = x[row*ncols + col];
+//         tmp += xi * xi;
+//     }
+
+//     // sum up partial sums
+//     tmp = warp_reduce_sum(tmp);
+//     if (block_size > WARP_SIZE) {
+//         __shared__ float s_sum[32];
+//         int warp_id = threadIdx.x / WARP_SIZE;
+//         int lane_id = threadIdx.x % WARP_SIZE;
+//         if (lane_id == 0) {
+//             s_sum[warp_id] = tmp;
+//         }
+//         __syncthreads();
+//         tmp = s_sum[lane_id];
+//         tmp = warp_reduce_sum(tmp);
+//     }
+
+//     // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
+//     const float scale = rsqrtf(fmaxf(tmp, eps * eps));
+
+//     for (int col = tid; col < ncols; col += block_size) {
+//         dst[row*ncols + col] = scale * x[row*ncols + col];
+//     }
+// }
+
+template <int block_size>
+static __global__ void l2_norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
+
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
+    const int tid       = threadIdx.x;
+
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
+        __shared__ float s_sum[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
+    const float scale = rsqrtf(fmaxf(tmp, eps * eps));
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[col] = scale * x[col];
+    }
+}
+
 static void norm_f32_cuda(
         const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
         const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
@@ -248,6 +327,19 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
     }
 }
 
+static void l2_norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    }
+}
+
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *) src0->data;
@@ -340,3 +432,27 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
 
     rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
 }
+
+void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
+
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
+}
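Note: the CUDA kernel scales by rsqrtf(fmaxf(tmp, eps * eps)) while the CPU path earlier in this diff uses 1.0f / fmaxf(sqrtf(sum), eps). The two agree because sqrt is monotonic:

    1 / sqrt(max(sum, eps^2)) = 1 / max(sqrt(sum), eps)

so both clamp the row norm from below by eps, matching torch.nn.functional.normalize as referenced in the kernel comment.
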
diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh
index d63d34380b0a709eda1c73bdf4a51f87e06f77e6..706a5660a680cb57de910dbb78e5d527d5fea09a 100644 (file)
@@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/wkv.cu b/ggml/src/ggml-cuda/wkv.cu
new file mode 100644 (file)
index 0000000..d2fced7
--- /dev/null
@@ -0,0 +1,199 @@
+#include "common.cuh"
+#include "wkv.cuh"
+
+template <int block_size>
+static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = block_size;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    __syncthreads();
+    _tf[tid] = tf[head_i * head_size + tid];
+    __syncthreads();
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4& k = (float4&)(_k[j]);
+            const float4& r = (float4&)(_r[j]);
+            const float4& tf = (float4&)(_tf[j]);
+            const float4& td = (float4&)(_td[j]);
+            float4& s = (float4&)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            y += r.x * (tf.x * kv.x + s.x);
+            y += r.y * (tf.y * kv.y + s.y);
+            y += r.z * (tf.z * kv.z + s.z);
+            y += r.w * (tf.w * kv.w + s.w);
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+        }
+        dst[t] = y;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+template <int block_size>
+static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = block_size;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];
+
+#ifndef GGML_USE_MUSA
+    #pragma unroll
+#endif
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
+    }
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+        __syncthreads();
+
+        float sa = 0;
+        #pragma unroll
+        for (int j = 0; j < head_size; j += 4)
+        {
+            const float4& a = (float4&)(_a[j]);
+            const float4& s = (float4&)(state[j]);
+            sa += a.x * s.x;
+            sa += a.y * s.y;
+            sa += a.z * s.z;
+            sa += a.w * s.w;
+        }
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4& r = (float4&)(_r[j]);
+            const float4& w = (float4&)(_w[j]);
+            const float4& k = (float4&)(_k[j]);
+            const float4& b = (float4&)(_b[j]);
+            float4& s = (float4&)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            s.x = s.x * w.x + kv.x + sa * b.x;
+            s.y = s.y * w.y + kv.y + sa * b.y;
+            s.z = s.z * w.z + kv.z + sa * b.z;
+            s.w = s.w * w.w + kv.w + sa * b.w;
+
+            y += s.x * r.x;
+            y += s.y * r.y;
+            y += s.z * r.z;
+            y += s.w * r.w;
+        }
+        dst[t] = y;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
+    }
+}
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d  = (const float *)dst->src[0]->data;
+    const float * v_d  = (const float *)dst->src[1]->data;
+    const float * r_d  = (const float *)dst->src[2]->data;
+    const float * tf_d = (const float *)dst->src[3]->data;
+    const float * td_d = (const float *)dst->src[4]->data;
+    const float * s_d  = (const float *)dst->src[5]->data;
+
+    const int64_t B = dst->src[5]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
+
+    if (C / H == CUDA_WKV_BLOCK_SIZE) {
+        rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+    } else {
+        rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+    }
+}
+
+void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * r_d = (const float *)dst->src[0]->data;
+    const float * w_d = (const float *)dst->src[1]->data;
+    const float * k_d = (const float *)dst->src[2]->data;
+    const float * v_d = (const float *)dst->src[3]->data;
+    const float * a_d = (const float *)dst->src[4]->data;
+    const float * b_d = (const float *)dst->src[5]->data;
+    const float * s_d = (const float *)dst->src[6]->data;
+
+    const int64_t B = dst->src[6]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
+
+    if (C / H == CUDA_WKV_BLOCK_SIZE) {
+        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+    } else {
+        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+    }
+}
diff --git a/ggml/src/ggml-cuda/wkv.cuh b/ggml/src/ggml-cuda/wkv.cuh
new file mode 100644 (file)
index 0000000..9623dd7
--- /dev/null
@@ -0,0 +1,7 @@
+#include "common.cuh"
+
+#define CUDA_WKV_BLOCK_SIZE 64
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/wkv6.cu b/ggml/src/ggml-cuda/wkv6.cu
deleted file mode 100644 (file)
index bbdafbe..0000000
+++ /dev/null
@@ -1,89 +0,0 @@
-#include "common.cuh"
-#include "wkv6.cuh"
-
-static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
-    const int tid = threadIdx.x;
-    const int bid = blockIdx.x;
-
-    const int head_size = CUDA_WKV_BLOCK_SIZE;
-    const int batch_i = bid / H;
-    const int head_i = bid % H;
-    const int state_size = C * head_size;
-    const int n_seq_tokens = T / B;
-
-    float state[head_size];
-    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
-
-    #pragma unroll
-    for (int i = 0; i < head_size; i++) {
-        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
-    }
-
-    __syncthreads();
-    _tf[tid] = tf[head_i * head_size + tid];
-    __syncthreads();
-
-    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
-        __syncthreads();
-        _k[tid] = k[t];
-        _r[tid] = r[t];
-        _td[tid] = td[t];
-        __syncthreads();
-
-        const float _v = v[t];
-        float y = 0;
-        for (int j = 0; j < head_size; j += 4) {
-            const float4& k = (float4&)(_k[j]);
-            const float4& r = (float4&)(_r[j]);
-            const float4& tf = (float4&)(_tf[j]);
-            const float4& td = (float4&)(_td[j]);
-            float4& s = (float4&)(state[j]);
-            float4 kv;
-
-            kv.x = k.x * _v;
-            kv.y = k.y * _v;
-            kv.z = k.z * _v;
-            kv.w = k.w * _v;
-
-            y += r.x * (tf.x * kv.x + s.x);
-            y += r.y * (tf.y * kv.y + s.y);
-            y += r.z * (tf.z * kv.z + s.z);
-            y += r.w * (tf.w * kv.w + s.w);
-
-            s.x = s.x * td.x + kv.x;
-            s.y = s.y * td.y + kv.y;
-            s.z = s.z * td.z + kv.z;
-            s.w = s.w * td.w + kv.w;
-        }
-        dst[t] = y;
-    }
-
-    #pragma unroll
-    for (int i = 0; i < head_size; i++) {
-        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
-    }
-}
-
-void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const float * k_d  = (const float *)dst->src[0]->data;
-    const float * v_d  = (const float *)dst->src[1]->data;
-    const float * r_d  = (const float *)dst->src[2]->data;
-    const float * tf_d = (const float *)dst->src[3]->data;
-    const float * td_d = (const float *)dst->src[4]->data;
-    const float * s_d  = (const float *)dst->src[5]->data;
-
-    const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[2];
-    const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[1];
-
-    float * dst_d = (float *)dst->data;
-
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
-    GGML_ASSERT(C % H == 0);
-    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
-
-    rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
-}
diff --git a/ggml/src/ggml-cuda/wkv6.cuh b/ggml/src/ggml-cuda/wkv6.cuh
deleted file mode 100644 (file)
index a7124ee..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-#include "common.cuh"
-
-#define CUDA_WKV_BLOCK_SIZE 64
-
-void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index a58c474eb007e4ccea031a40a003808e19570d97..1e954b4ceabd7c6685caf3a025678e31f2492f93 100644 (file)
@@ -285,6 +285,13 @@ typedef struct {
     float    eps;
 } ggml_metal_kargs_rms_norm;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne00_4;
+    uint64_t nb01;
+    float    eps;
+} ggml_metal_kargs_l2_norm;
+
 typedef struct {
     int64_t  ne00;
     int64_t  ne01;
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index e51a4169a23bf6efbaaeea571bae5b03719d2bdc..af65e7d9f53d43a9ad2b2d9b970144efabec8e2f 100644 (file)
@@ -184,10 +184,13 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    GGML_METAL_KERNEL_TYPE_L2_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+    GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
+    GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
@@ -810,10 +813,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,               get_rows_iq4_xs,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,                  get_rows_i32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                      rms_norm,                       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM,                       l2_norm,                        has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                    group_norm,                     has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                          norm,                           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                  ssm_conv_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                  ssm_scan_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,                 rwkv_wkv6_f32,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,                 rwkv_wkv7_f32,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,               mul_mv_bf16_f32,                has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,          mul_mv_bf16_f32_1row,           has_simdgroup_reduction && use_bfloat);
@@ -1251,6 +1257,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ARGMAX:
             return true;
@@ -1288,6 +1295,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_RWKV_WKV7:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -2216,6 +2225,83 @@ static void ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
+        case GGML_OP_RWKV_WKV6:
+            {
+                const int64_t B = dst->src[5]->ne[1];
+                const int64_t T = dst->src[0]->ne[2];
+                const int64_t C = dst->ne[0];
+                const int64_t H = dst->src[0]->ne[1];
+
+                GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+                GGML_ASSERT(C % H == 0);
+                GGML_ASSERT(C / H == 64);
+
+                size_t offs_src3 = 0;
+                size_t offs_src4 = 0;
+                size_t offs_src5 = 0;
+
+                id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
+                id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
+                id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+                [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
+
+                [encoder setBytes:&B length:sizeof(B) atIndex:7];
+                [encoder setBytes:&T length:sizeof(T) atIndex:8];
+                [encoder setBytes:&C length:sizeof(C) atIndex:9];
+                [encoder setBytes:&H length:sizeof(H) atIndex:10];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
+            } break;
+        case GGML_OP_RWKV_WKV7:
+            {
+                const int64_t B = dst->src[6]->ne[1];
+                const int64_t T = dst->src[0]->ne[2];
+                const int64_t C = dst->ne[0];
+                const int64_t H = dst->src[0]->ne[1];
+
+                GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+                GGML_ASSERT(C % H == 0);
+                GGML_ASSERT(C / H == 64);
+
+                size_t offs_src3 = 0;
+                size_t offs_src4 = 0;
+                size_t offs_src5 = 0;
+                size_t offs_src6 = 0;
+
+                id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
+                id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
+                id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
+                id<MTLBuffer> id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+                [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+                [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:7];
+
+                [encoder setBytes:&B length:sizeof(B) atIndex:8];
+                [encoder setBytes:&T length:sizeof(T) atIndex:9];
+                [encoder setBytes:&C length:sizeof(C) atIndex:10];
+                [encoder setBytes:&H length:sizeof(H) atIndex:11];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 GGML_ASSERT(ne00 == ne10);
@@ -3122,6 +3208,42 @@ static void ggml_metal_encode_node(
 
                 const int64_t nrows = ggml_nrows(src0);
 
+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_L2_NORM:
+            {
+                GGML_ASSERT(ne00 % 4 == 0);
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+                float eps;
+                memcpy(&eps, dst->op_params, sizeof(float));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;
+
+                int nth = 32; // SIMD width
+
+                while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
+
+                nth = MIN(nth, ne00/4);
+
+                ggml_metal_kargs_l2_norm args = {
+                    /*.ne00   =*/ ne00,
+                    /*.ne00_4 =*/ ne00/4,
+                    /*.nb01   =*/ nb01,
+                    /*.eps    =*/ eps,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                const int64_t nrows = ggml_nrows(src0);
+
                 [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_GROUP_NORM:
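Note on the GGML_OP_L2_NORM dispatch above: the threadgroup size starts at one SIMD group and doubles until it either covers all float4 elements of a row or hits the pipeline's threadgroup limit. A minimal host-side C++ sketch of that sizing rule (illustrative only, not part of the commit):

#include <algorithm>
#include <cstdint>

// Mirrors the nth computation in the GGML_OP_L2_NORM case: grow by whole
// SIMD groups, then clamp to the number of float4 elements per row.
static int l2_norm_threads(int64_t ne00, int max_threads_per_group) {
    int nth = 32; // SIMD width
    while (nth < ne00/4 && nth < max_threads_per_group) {
        nth *= 2;
    }
    return (int) std::min<int64_t>(nth, ne00/4);
}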
index ad9d42a3eaa9e75497b95e03814b95efd7a3f0b7..3cef81b797197a2160c43175654c813d9ac3288c 100644 (file)
@@ -1295,6 +1295,184 @@ kernel void kernel_ssm_scan_f32(
     }
 }
 
+kernel void kernel_rwkv_wkv6_f32(
+    device const float * k,
+    device const float * v,
+    device const float * r,
+    device const float * tf,
+    device const float * td,
+    device const float * state_in,
+    device       float * dst,
+    constant    uint & B,
+    constant    uint & T,
+    constant    uint & C,
+    constant    uint & H,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]])  {
+
+    const uint head_size = 64; // TODO: support head_size = 128
+    const uint batch_id = tgpig.x / H;
+    const uint head_id = tgpig.x % H;
+    const uint tid = tpitg.x;
+
+    if (batch_id >= B || head_id >= H) {
+        return;
+    }
+
+    const uint state_size = C * head_size;
+    const uint n_seq_tokens = T / B;
+
+    threadgroup float _k[head_size];
+    threadgroup float _r[head_size];
+    threadgroup float _tf[head_size];
+    threadgroup float _td[head_size];
+
+    float state[head_size];
+
+    for (uint i = 0; i < head_size; i++) {
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+                          + i * head_size + tid];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    _tf[tid] = tf[head_id * head_size + tid];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+    for (uint t = start_t; t < end_t; t += C) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        const float v_val = v[t];
+        float y = 0.0;
+
+        for (uint j = 0; j < head_size; j += 4) {
+            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+            float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            float4 kv = k_vec * v_val;
+
+            float4 temp = tf_vec * kv + s_vec;
+            y += dot(r_vec, temp);
+
+            s_vec = s_vec * td_vec + kv;
+            state[j]   = s_vec[0];
+            state[j+1] = s_vec[1];
+            state[j+2] = s_vec[2];
+            state[j+3] = s_vec[3];
+        }
+
+        dst[t] = y;
+    }
+
+    for (uint i = 0; i < head_size; i++) {
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+            + i * head_size + tid] = state[i];
+    }
+}
+
+kernel void kernel_rwkv_wkv7_f32(
+    device const float * r,
+    device const float * w,
+    device const float * k,
+    device const float * v,
+    device const float * a,
+    device const float * b,
+    device const float * state_in,
+    device       float * dst,
+    constant    uint & B,
+    constant    uint & T,
+    constant    uint & C,
+    constant    uint & H,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]])  {
+
+    const uint head_size = 64; // TODO: support head_size = 128
+    const uint batch_id = tgpig.x / H;
+    const uint head_id = tgpig.x % H;
+    const uint tid = tpitg.x;
+
+    if (batch_id >= B || head_id >= H) {
+        return;
+    }
+
+    const uint state_size = C * head_size;
+    const uint n_seq_tokens = T / B;
+
+    threadgroup float _r[head_size];
+    threadgroup float _w[head_size];
+    threadgroup float _k[head_size];
+    threadgroup float _a[head_size];
+    threadgroup float _b[head_size];
+
+    float state[head_size];
+
+    for (uint i = 0; i < head_size; i++) {
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+                          + tid * head_size + i];
+    }
+
+    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+    for (uint t = start_t; t < end_t; t += C) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        const float v_val = v[t];
+        float y = 0.0, sa = 0.0;
+
+        float4 sa_vec(0.0);
+
+        for (uint j = 0; j < head_size; j += 4) {
+            float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
+            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+            sa_vec += a_vec * s_vec;
+        }
+        sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
+
+        for (uint j = 0; j < head_size; j += 4) {
+            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
+            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
+            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            float4 kv = k_vec * v_val;
+
+            s_vec = s_vec * w_vec + kv + sa * b_vec;
+            y += dot(s_vec, r_vec);
+
+            state[j]   = s_vec[0];
+            state[j+1] = s_vec[1];
+            state[j+2] = s_vec[2];
+            state[j+3] = s_vec[3];
+        }
+
+        dst[t] = y;
+    }
+
+    for (uint i = 0; i < head_size; i++) {
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+            + tid * head_size + i] = state[i];
+    }
+}
+
 kernel void kernel_argmax(
         device   const void * x,
         device      int32_t * dst,
@@ -1463,6 +1641,49 @@ kernel void kernel_rms_norm(
     }
 }
 
+kernel void kernel_l2_norm(
+        constant ggml_metal_kargs_l2_norm & args,
+        device const char * src0,
+        device       char * dst,
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint   tgpig[[threadgroup_position_in_grid]],
+        ushort tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort   ntg[[threads_per_threadgroup]]) {
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
+    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+
+    float sumf = 0.0f;
+
+    // parallel sum
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        sumf += dot(x[i00], x[i00]);
+    }
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    const float scale = 1.0f/sqrt(max(sumf, args.eps));
+
+    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        y[i00] = x[i00] * scale;
+    }
+}
+
 kernel void kernel_group_norm(
         device const float * src0,
         device       float * dst,
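For readers of the WKV7 kernel above: each thread owns one output channel of a head and carries a head_size-wide slice of the recurrent state, while r, w, k, a, b are broadcast through threadgroup memory. A plain scalar C++ reference of the same per-token update (a sketch for clarity, not part of the commit):

#include <cstddef>

// One head, one token. r, w, k, a, b and v each hold head_size floats;
// state is head_size x head_size with row i owned by output channel i
// (the role played by `tid` in the kernels).
static void wkv7_token_ref(const float * r, const float * w, const float * k,
                           const float * v, const float * a, const float * b,
                           float * state, float * y, size_t head_size) {
    for (size_t i = 0; i < head_size; ++i) {
        float * s = state + i * head_size;

        // sa = a . state_row: the "in-context" removal coefficient
        float sa = 0.0f;
        for (size_t j = 0; j < head_size; ++j) {
            sa += a[j] * s[j];
        }

        float yi = 0.0f;
        for (size_t j = 0; j < head_size; ++j) {
            s[j] = s[j] * w[j] + k[j] * v[i] + sa * b[j]; // decay + write + replace
            yi  += r[j] * s[j];                           // read out with r
        }
        y[i] = yi;
    }
}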
index 577ff51fde5a87ffc55bf873b75eb5857153aa7f..73d807cab0be9950db3d9687a7f9b0d224376218 100644 (file)
@@ -26,7 +26,7 @@
 #include "softmax.hpp"
 #include "tsembd.hpp"
 #include "im2col.hpp"
-#include "wkv6.hpp"
+#include "wkv.hpp"
 #include "outprod.hpp"
 #include "element_wise.hpp"
 #include "cpy.hpp"
index 05984d8c5ac4ef5f55c1f2ea2008d08eebd0b974..477652ab283eef8242b885280d2b14d5340a394b 100644 (file)
@@ -2696,6 +2696,12 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
+static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
 static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
@@ -3410,6 +3416,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
         case GGML_OP_RMS_NORM:
             ggml_sycl_rms_norm(ctx, dst);
             break;
+        case GGML_OP_L2_NORM:
+            ggml_sycl_l2_norm(ctx, dst);
+            break;
         case GGML_OP_MUL_MAT:
             if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                 return false;
@@ -3487,6 +3496,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
         case GGML_OP_RWKV_WKV6:
             ggml_sycl_op_rwkv_wkv6(ctx, dst);
             break;
+        case GGML_OP_RWKV_WKV7:
+            ggml_sycl_op_rwkv_wkv7(ctx, dst);
+            break;
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_sycl_op_gated_linear_attn(ctx, dst);
             break;
@@ -4012,6 +4024,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return (op->src[0]->type == GGML_TYPE_F32);
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_SCALE:
@@ -4045,6 +4058,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_RWKV_WKV7:
         case GGML_OP_GATED_LINEAR_ATTN:
             return true;
         default:
index 9cf2be15575d8ade92bb07c56a37fa221dac8c71..6439db21b2978ee76aefccd573665305ff542fe0 100644 (file)
@@ -180,6 +180,50 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
     }
 }
 
+static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
+    const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+        item_ct1.get_local_id(1);
+    const int tid = item_ct1.get_local_id(2);
+    const int nthreads = item_ct1.get_local_range(2);
+    const int nwarps = nthreads / WARP_SIZE;
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[row * ncols + col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp, item_ct1);
+    if (block_size > WARP_SIZE) {
+        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        /*
+        DPCT1118:3: SYCL group functions and algorithms must be encountered in
+        converged control flow. You may need to adjust the code.
+        */
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+        size_t nreduce = nwarps / WARP_SIZE;
+        tmp = 0.f;
+        for (size_t i = 0; i < nreduce; i += 1) {
+            tmp += s_sum[lane_id + i * WARP_SIZE];
+        }
+        tmp = warp_reduce_sum(tmp, item_ct1);
+    }
+
+    const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row * ncols + col] = scale * x[row * ncols + col];
+    }
+}
+
 static void norm_f32_sycl(const float* x, float* dst, const int ncols,
     const int nrows, const float eps,
     queue_ptr stream, int device) {
@@ -311,6 +355,48 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
     }
 }
 
+static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
+    const int nrows, const float eps,
+    queue_ptr stream, int device) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
+    if (ncols < 1024) {
+        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+        stream->submit([&](sycl::handler& cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                    block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    l2_norm_f32(x, dst, ncols, eps, item_ct1,
+                        nullptr, WARP_SIZE);
+                });
+            });
+    }
+    else {
+        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
+        const sycl::range<3> block_dims(1, 1, work_group_size);
+        /*
+        DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        stream->submit([&](sycl::handler& cgh) {
+            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
+                cgh);
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                    block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    l2_norm_f32(x, dst, ncols, eps, item_ct1,
+                        get_pointer(s_sum_acc_ct1), work_group_size);
+                });
+            });
+    }
+}
+
 void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
     ggml_tensor* dst, const float* src0_dd,
     const float* src1_dd, float* dst_dd,
@@ -376,3 +462,25 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
     (void)dst;
     (void)src1_dd;
 }
+
+void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+    const ggml_tensor* src1, ggml_tensor* dst,
+    const float* src0_dd, const float* src1_dd,
+    float* dst_dd,
+    const queue_ptr& main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+
+    (void)src1;
+    (void)dst;
+    (void)src1_dd;
+}
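The SYCL kernel above reduces the squared row sum across the work-group and rescales; its per-row effect is the following single-threaded reference (note it clamps with eps * eps under the square root, matching the SYCL code above):

#include <algorithm>
#include <cmath>

// Row-wise L2 normalization: dst = x / max(||x||, eps), one row at a time.
static void l2_norm_row_ref(const float * x, float * dst, int ncols, float eps) {
    float sum = 0.0f;
    for (int c = 0; c < ncols; ++c) {
        sum += x[c] * x[c];
    }
    const float scale = 1.0f / std::sqrt(std::max(sum, eps * eps));
    for (int c = 0; c < ncols; ++c) {
        dst[c] = scale * x[c];
    }
}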
index a9ad9156fa33e0a500fe30f9dc31feed4ca81b66..11e91680cc496c67d532486ad7947beac02b1daa 100644 (file)
@@ -32,4 +32,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
     float* dst_dd,
     const queue_ptr& main_stream);
 
+void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+    const ggml_tensor* src1, ggml_tensor* dst,
+    const float* src0_dd, const float* src1_dd,
+    float* dst_dd,
+    const queue_ptr& main_stream);
+
 #endif // GGML_SYCL_NORM_HPP
diff --git a/ggml/src/ggml-sycl/wkv.cpp b/ggml/src/ggml-sycl/wkv.cpp
new file mode 100644 (file)
index 0000000..540f6fb
--- /dev/null
@@ -0,0 +1,305 @@
+#include <sycl/sycl.hpp>
+#include "wkv.hpp"
+
+constexpr int WKV_BLOCK_SIZE = 64;  // Matching CUDA_WKV_BLOCK_SIZE
+
+// Helper function for the main kernel
+template <int block_size>
+static void rwkv_wkv6_f32_kernel(
+    const int B, const int T, const int C, const int H,
+    const float* k, const float* v, const float* r,
+    const float* tf, const float* td, const float* s,
+    float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
+
+    const int tid = item_ct1.get_local_id(2);
+    const int bid = item_ct1.get_group(2);
+
+    const int head_size = block_size;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    // Set up shared memory pointers
+    float* _k = shared_mem;
+    float* _r = _k + head_size;
+    float* _tf = _r + head_size;
+    float* _td = _tf + head_size;
+
+    // Local state array
+    float state[block_size];
+
+    // Load initial state
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    // Sync threads before shared memory operations
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    // Load time-mixing parameters
+    _tf[tid] = tf[head_i * head_size + tid];
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    // Main sequence processing loop
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
+         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
+         t += C) {
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        // Load current timestep data to shared memory
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        const float _v = v[t];
+        float y = 0;
+
+        // Process in chunks of 4 for better vectorization
+        sycl::float4 k4, r4, tf4, td4, s4;
+        #pragma unroll
+        for (int j = 0; j < head_size; j += 4) {
+            // Load data in vec4 chunks
+            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+            td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            // Compute key-value product
+            sycl::float4 kv4 = k4 * _v;
+
+            // Accumulate weighted sum
+            y += sycl::dot(r4, tf4 * kv4 + s4);
+
+            // Update state
+            s4 = s4 * td4 + kv4;
+
+            // Store updated state
+            state[j] = s4.x();
+            state[j+1] = s4.y();
+            state[j+2] = s4.z();
+            state[j+3] = s4.w();
+        }
+
+        dst[t] = y;
+    }
+
+    // Save final state
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+template <int block_size>
+static void rwkv_wkv7_f32_kernel(
+    const int B, const int T, const int C, const int H,
+    const float* r, const float* w, const float* k, const float* v,
+    const float* a, const float* b, const float* s,
+    float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
+
+    const int tid = item_ct1.get_local_id(2);
+    const int bid = item_ct1.get_group(2);
+
+    const int head_size = block_size;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float* _r = shared_mem;
+    float* _w = _r + head_size;
+    float* _k = _w + head_size;
+    float* _a = _k + head_size;
+    float* _b = _a + head_size;
+
+    float state[block_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
+    }
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
+         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
+         t += C) {
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        const float _v = v[t];
+        float y = 0, sa = 0;
+        sycl::float4 a4, s4;
+
+        #pragma unroll
+        for (int j = 0; j < head_size; j += 4) {
+            a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
+            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
+            sa += sycl::dot(a4, s4);
+        }
+
+        sycl::float4 r4, w4, k4, b4;
+        #pragma unroll
+        for (int j = 0; j < head_size; j += 4) {
+            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
+            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
+            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            sycl::float4 kv4 = k4 * _v;
+
+            s4 = s4 * w4 + kv4 + sa * b4;
+            y += sycl::dot(r4, s4);
+
+            state[j] = s4.x();
+            state[j+1] = s4.y();
+            state[j+2] = s4.z();
+            state[j+3] = s4.w();
+        }
+
+        dst[t] = y;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
+    }
+}
+
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
+
+    const float* k_d = (const float*)dst->src[0]->data;
+    const float* v_d = (const float*)dst->src[1]->data;
+    const float* r_d = (const float*)dst->src[2]->data;
+    const float* tf_d = (const float*)dst->src[3]->data;
+    const float* td_d = (const float*)dst->src[4]->data;
+    const float* s_d = (const float*)dst->src[5]->data;
+    float* dst_d = (float*)dst->data;
+
+    const int64_t B = dst->src[5]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // the SYCL kernels support HEAD_SIZE == 64 or 128
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    // Calculate execution configuration
+    const size_t shared_mem_size = C / H * 4 * sizeof(float); // For k, r, tf, td
+    sycl::range<3> block_dims(1, 1, C / H);
+    sycl::range<3> grid_dims(1, 1, B * H);
+
+    // Submit kernel
+    if (C / H == WKV_BLOCK_SIZE) {
+        stream->submit([&](sycl::handler& cgh) {
+            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
+                        B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
+                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+                    );
+                });
+        });
+    } else {
+        stream->submit([&](sycl::handler& cgh) {
+            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
+                        B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
+                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+                    );
+                });
+        });
+    }
+
+    GGML_UNUSED(src0);
+    GGML_UNUSED(src1);
+}
+
+void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
+
+    const float* r_d = (const float*)dst->src[0]->data;
+    const float* w_d = (const float*)dst->src[1]->data;
+    const float* k_d = (const float*)dst->src[2]->data;
+    const float* v_d = (const float*)dst->src[3]->data;
+    const float* a_d = (const float*)dst->src[4]->data;
+    const float* b_d = (const float*)dst->src[5]->data;
+    const float* s_d = (const float*)dst->src[6]->data;
+    float* dst_d = (float*)dst->data;
+
+    const int64_t B = dst->src[6]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2);
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    // Calculate execution configuration
+    const size_t shared_mem_size = C / H * 5 * sizeof(float); // For r, w, k, a, b
+    sycl::range<3> block_dims(1, 1, C / H);
+    sycl::range<3> grid_dims(1, 1, B * H);
+
+    // Submit kernel
+    if (C / H == WKV_BLOCK_SIZE) {
+        stream->submit([&](sycl::handler& cgh) {
+            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
+                        B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
+                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+                    );
+                });
+        });
+    } else {
+        stream->submit([&](sycl::handler& cgh) {
+            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
+                        B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
+                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+                    );
+                });
+        });
+    }
+
+    GGML_UNUSED(src0);
+    GGML_UNUSED(src1);
+}
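Both SYCL entry points above use the same launch shape: one work-group per (sequence, head) pair, one work-item per channel of a head, and local memory for the per-token broadcast arrays (4 arrays for WKV6: k, r, tf, td; 5 for WKV7: r, w, k, a, b). A hypothetical C++ helper summarizing that configuration, assuming the same B, C, H meanings as the code above:

#include <cstddef>
#include <cstdint>

struct wkv_launch {
    size_t work_groups;   // B * H: one per (sequence, head)
    size_t local_size;    // head_size = C / H: one work-item per channel
    size_t shared_floats; // head_size * n_broadcast_arrays
};

static wkv_launch wkv_launch_config(int64_t B, int64_t C, int64_t H,
                                    int n_broadcast_arrays) {
    const int64_t head_size = C / H; // asserted above to be 64 or 128
    return {
        (size_t)(B * H),
        (size_t)head_size,
        (size_t)(head_size * n_broadcast_arrays),
    };
}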
diff --git a/ggml/src/ggml-sycl/wkv.hpp b/ggml/src/ggml-sycl/wkv.hpp
new file mode 100644 (file)
index 0000000..9f34a10
--- /dev/null
@@ -0,0 +1,10 @@
+#ifndef GGML_SYCL_WKV_HPP
+#define GGML_SYCL_WKV_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif // GGML_SYCL_WKV_HPP
diff --git a/ggml/src/ggml-sycl/wkv6.cpp b/ggml/src/ggml-sycl/wkv6.cpp
deleted file mode 100644 (file)
index b54c209..0000000
+++ /dev/null
@@ -1,143 +0,0 @@
-#include <sycl/sycl.hpp>
-#include "wkv6.hpp"
-
-constexpr int WKV_BLOCK_SIZE = 64;  // Matching CUDA_WKV_BLOCK_SIZE
-
-// Helper function for the main kernel
-static void rwkv_wkv_f32_kernel(
-    const int B, const int T, const int C, const int H,
-    const float* k, const float* v, const float* r,
-    const float* tf, const float* td, const float* s,
-    float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
-
-    const int tid = item_ct1.get_local_id(2);
-    const int bid = item_ct1.get_group(2);
-
-    const int head_size = WKV_BLOCK_SIZE;
-    const int batch_i = bid / H;
-    const int head_i = bid % H;
-    const int state_size = C * head_size;
-    const int n_seq_tokens = T / B;
-
-    // Set up shared memory pointers
-    float* _k = shared_mem;
-    float* _r = _k + head_size;
-    float* _tf = _r + head_size;
-    float* _td = _tf + head_size;
-
-    // Local state array
-    float state[WKV_BLOCK_SIZE];
-
-    // Load initial state
-    #pragma unroll
-    for (int i = 0; i < head_size; i++) {
-        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
-    }
-
-    // Sync threads before shared memory operations
-    item_ct1.barrier(sycl::access::fence_space::local_space);
-
-    // Load time-mixing parameters
-    _tf[tid] = tf[head_i * head_size + tid];
-    item_ct1.barrier(sycl::access::fence_space::local_space);
-
-    // Main sequence processing loop
-    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
-         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
-         t += C) {
-
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        // Load current timestep data to shared memory
-        _k[tid] = k[t];
-        _r[tid] = r[t];
-        _td[tid] = td[t];
-
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        const float _v = v[t];
-        float y = 0;
-
-        // Process in chunks of 4 for better vectorization
-        sycl::float4 k4, r4, tf4, td4, s4;
-        #pragma unroll
-        for (int j = 0; j < head_size; j += 4) {
-            // Load data in vec4 chunks
-            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
-            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
-            tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
-            td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
-            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
-
-            // Compute key-value product
-            sycl::float4 kv4 = k4 * _v;
-
-            // Accumulate weighted sum
-            y += sycl::dot(r4, tf4 * kv4 + s4);
-
-            // Update state
-            s4 = s4 * td4 + kv4;
-
-            // Store updated state
-            state[j] = s4.x();
-            state[j+1] = s4.y();
-            state[j+2] = s4.z();
-            state[j+3] = s4.w();
-        }
-
-        dst[t] = y;
-    }
-
-    // Save final state
-    #pragma unroll
-    for (int i = 0; i < head_size; i++) {
-        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
-    }
-}
-
-void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
-
-    const ggml_tensor *src0 = dst->src[0];
-    const ggml_tensor *src1 = dst->src[1];
-
-    const float* k_d = (const float*)dst->src[0]->data;
-    const float* v_d = (const float*)dst->src[1]->data;
-    const float* r_d = (const float*)dst->src[2]->data;
-    const float* tf_d = (const float*)dst->src[3]->data;
-    const float* td_d = (const float*)dst->src[4]->data;
-    const float* s_d = (const float*)dst->src[5]->data;
-    float* dst_d = (float*)dst->data;
-
-    const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[2];
-    const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[1];
-
-    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
-    GGML_ASSERT(C % H == 0);
-    GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
-
-    dpct::queue_ptr stream = ctx.stream();
-
-    // Calculate execution configuration
-    const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
-    sycl::range<3> block_dims(1, 1, C / H);
-    sycl::range<3> grid_dims(1, 1, B * H);
-
-    // Submit kernel
-    stream->submit([&](sycl::handler& cgh) {
-        sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
-
-        cgh.parallel_for(
-            sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rwkv_wkv_f32_kernel(
-                    B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
-                    item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
-                );
-            });
-    });
-
-    GGML_UNUSED(src0);
-    GGML_UNUSED(src1);
-}
diff --git a/ggml/src/ggml-sycl/wkv6.hpp b/ggml/src/ggml-sycl/wkv6.hpp
deleted file mode 100644 (file)
index 8c596a9..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-#ifndef GGML_SYCL_WKV6_HPP
-#define GGML_SYCL_WKV6_HPP
-
-#include "common.hpp"
-
-void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
-
-
-#endif // GGML_SYCL_WKV6_HPP
index 97398f071b80ed04367113e2003c308f0f543601..c0ee5dadef78a346fd7af16c5cd6fe109f88b480 100644 (file)
@@ -304,6 +304,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
+    vk_pipeline pipeline_l2_norm_f32;
     vk_pipeline pipeline_gelu_f32;
     vk_pipeline pipeline_gelu_quick_f32;
     vk_pipeline pipeline_silu_f32;
@@ -328,6 +329,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
+    vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
@@ -629,6 +631,13 @@ struct vk_op_rwkv_wkv6_push_constants {
     uint32_t H;
 };
 
+struct vk_op_rwkv_wkv7_push_constants {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t H;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -2263,6 +2272,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -2374,6 +2384,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
     for (auto &c : compiles) {
@@ -5473,6 +5485,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rms_norm_back_f32;
         }
         return nullptr;
+    case GGML_OP_L2_NORM:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_l2_norm_f32;
+        }
+        return nullptr;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(dst)) {
             case GGML_UNARY_OP_SILU:
@@ -5612,6 +5629,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rwkv_wkv6_f32;
         }
         return nullptr;
+    case GGML_OP_RWKV_WKV7:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rwkv_wkv7_f32;
+        }
+        return nullptr;
     case GGML_OP_OPT_STEP_ADAMW:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_opt_step_adamw_f32;
@@ -5859,6 +5881,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_NORM:
     case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
+    case GGML_OP_L2_NORM:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
@@ -6108,23 +6131,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
-static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
-    const ggml_tensor * k = dst->src[0];
-    const ggml_tensor * v = dst->src[1];
-    const ggml_tensor * r = dst->src[2];
-    const ggml_tensor * tf = dst->src[3];
-    const ggml_tensor * td = dst->src[4];
-    const ggml_tensor * state = dst->src[5];
-
-    GGML_ASSERT(!ggml_is_quantized(k->type));
-    GGML_ASSERT(!ggml_is_quantized(v->type));
-    GGML_ASSERT(!ggml_is_quantized(r->type));
-    GGML_ASSERT(!ggml_is_quantized(tf->type));
-    GGML_ASSERT(!ggml_is_quantized(td->type));
-    GGML_ASSERT(!ggml_is_quantized(state->type));
+static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
+    GGML_ASSERT(version == 6 || version == 7);
+    int num_srcs = version == 6 ? 6 : 7;
+
+    for (int i = 0; i < num_srcs; i++) {
+        GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
+    }
+
     GGML_ASSERT(dst->buffer != nullptr);
 
-    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
     GGML_ASSERT(pipeline != nullptr);
 
     if (dryrun) {
@@ -6133,89 +6150,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
     }
 
     ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
-    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
-    ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
-    ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
-    ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
-    ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
+    ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
+    for (int i = 0; i < num_srcs; i++) {
+        src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
+    }
 
     ggml_vk_sync_buffers(subctx);
 
-    vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr;
-    size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0;
-    bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
+    vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
+    size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
+    bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
 
     if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
-        ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
-        ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
-        ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
-        ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
-        ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
-        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+        for (int i = 0; i < num_srcs; i++) {
+            ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
+            srcs_uma[i] = d_srcs[i] != nullptr;
+        }
 
-        K_uma = d_K != nullptr;
-        V_uma = d_V != nullptr;
-        R_uma = d_R != nullptr;
-        TF_uma = d_TF != nullptr;
-        TD_uma = d_TD != nullptr;
-        STATE_uma = d_State != nullptr;
-        DST_uma = d_D != nullptr;
+        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+        dst_uma = d_D != nullptr;
     }
 
-    if (!K_uma) {
-        d_K = k_buf_ctx->dev_buffer;
-        k_offset = vk_tensor_offset(k) + k->view_offs;
-    }
-    if (!V_uma) {
-        d_V = v_buf_ctx->dev_buffer;
-        v_offset = vk_tensor_offset(v) + v->view_offs;
-    }
-    if (!R_uma) {
-        d_R = r_buf_ctx->dev_buffer;
-        r_offset = vk_tensor_offset(r) + r->view_offs;
-    }
-    if (!TF_uma) {
-        d_TF = tf_buf_ctx->dev_buffer;
-        tf_offset = vk_tensor_offset(tf) + tf->view_offs;
-    }
-    if (!TD_uma) {
-        d_TD = td_buf_ctx->dev_buffer;
-        td_offset = vk_tensor_offset(td) + td->view_offs;
-    }
-    if (!STATE_uma) {
-        d_State = state_buf_ctx->dev_buffer;
-        state_offset = vk_tensor_offset(state) + state->view_offs;
+    uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
+    for (int i = 0; i < num_srcs; i++) {
+        src_sizes[i] = ggml_nbytes(dst->src[i]);
+        if (!srcs_uma[i]) {
+            d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
+            src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
+        }
     }
-    if (!DST_uma) {
+
+    const uint64_t dst_size = ggml_nbytes(dst);
+    if (!dst_uma) {
         d_D = dst_buf_ctx->dev_buffer;
         dst_offset = vk_tensor_offset(dst) + dst->view_offs;
     }
 
-    const uint64_t k_size = ggml_nbytes(k);
-    const uint64_t v_size = ggml_nbytes(v);
-    const uint64_t r_size = ggml_nbytes(r);
-    const uint64_t tf_size = ggml_nbytes(tf);
-    const uint64_t td_size = ggml_nbytes(td);
-    const uint64_t state_size = ggml_nbytes(state);
-    const uint64_t dst_size = ggml_nbytes(dst);
-
     std::array<uint32_t, 3> elements = {
         (uint32_t)(pc.B * pc.H),
         1,
         1
     };
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
-        vk_subbuffer{ d_K, k_offset, k_size },
-        vk_subbuffer{ d_V, v_offset, v_size },
-        vk_subbuffer{ d_R, r_offset, r_size },
-        vk_subbuffer{ d_TF, tf_offset, tf_size },
-        vk_subbuffer{ d_TD, td_offset, td_size },
-        vk_subbuffer{ d_State, state_offset, state_size },
-        vk_subbuffer{ d_D, dst_offset, dst_size }
-    }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+    if (version == 6) {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+            vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
+            vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
+            vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
+            vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
+            vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
+            vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
+            vk_subbuffer{ d_D, dst_offset, dst_size }
+        }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+    } else if (version == 7) {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+            vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
+            vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
+            vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
+            vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
+            vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
+            vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
+            vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
+            vk_subbuffer{ d_D, dst_offset, dst_size }
+        }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
+    } else {
+        // shouldn't happen
+        GGML_ASSERT(false);
+    }
 }
 
 static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
@@ -6224,7 +6225,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
     const size_t n_heads = dst->src[0]->ne[1];
     const size_t n_seqs = dst->src[5]->ne[1];
 
-    ggml_vk_op_f32_rwkv6(
+    ggml_vk_op_f32_wkv(
+        ctx, subctx, dst,
+        {
+            (uint32_t)n_seqs,
+            (uint32_t)seq_length,
+            (uint32_t)n_embed,
+            (uint32_t)n_heads,
+        },
+        6,
+        dryrun
+    );
+}
+
+static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+    const size_t seq_length = dst->src[0]->ne[2];
+    const size_t n_embed = dst->ne[0];
+    const size_t n_heads = dst->src[0]->ne[1];
+    const size_t n_seqs = dst->src[6]->ne[1];
+
+    ggml_vk_op_f32_wkv(
         ctx, subctx, dst,
         {
             (uint32_t)n_seqs,
@@ -6232,6 +6252,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
             (uint32_t)n_embed,
             (uint32_t)n_heads,
         },
+        7,
         dryrun
     );
 }
@@ -6533,6 +6554,11 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
 }
 
+static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+}
+
 static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
@@ -7528,6 +7554,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
+    case GGML_OP_L2_NORM:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SOFT_MAX_BACK:
@@ -7544,6 +7571,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
     case GGML_OP_RWKV_WKV6:
+    case GGML_OP_RWKV_WKV7:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
     case GGML_OP_OPT_STEP_ADAMW:
@@ -7590,6 +7618,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
         case GGML_OP_GROUP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_RMS_NORM_BACK:
+        case GGML_OP_L2_NORM:
         case GGML_OP_UNARY:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
@@ -7707,6 +7736,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM_BACK:
         ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_L2_NORM:
+        ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
+
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(node)) {
@@ -7797,6 +7830,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
         break;
 
+    case GGML_OP_RWKV_WKV7:
+        ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
+
+        break;
+
     case GGML_OP_OPT_STEP_ADAMW:
         ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
 
@@ -7870,6 +7908,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
+    case GGML_OP_L2_NORM:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SOFT_MAX_BACK:
@@ -7889,6 +7928,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
     case GGML_OP_RWKV_WKV6:
+    case GGML_OP_RWKV_WKV7:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
@@ -8806,6 +8846,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
         case GGML_OP_SUB:
@@ -8835,6 +8876,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_POOL_2D:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_RWKV_WKV7:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_OPT_STEP_ADAMW:
             return true;
@@ -9219,6 +9261,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
     } else if (tensor->op == GGML_OP_SILU_BACK) {
         tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
+    } else if (tensor->op == GGML_OP_L2_NORM) {
+        const float eps = ((float *) tensor->op_params)[0];
+        tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
             tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
@@ -9338,6 +9383,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_RWKV_WKV6) {
         tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
         src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
+    } else if (tensor->op == GGML_OP_RWKV_WKV7) {
+        tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
+        src_clone[4], src_clone[5], src_clone[6]);
     } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
         src_clone[0]->flags = src0->flags;
         tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
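The Vulkan refactor above folds both WKV versions into one dispatch helper: the push constants (B, T, C, H) are identical, so only the number of bound source buffers differs per version. A small sketch of that mapping (illustrative):

// Source tensors bound per WKV version, in binding order.
static int wkv_num_srcs(int version) {
    switch (version) {
        case 6: return 6; // k, v, r, tf, td, state
        case 7: return 7; // r, w, k, v, a, b, state
        default: return -1; // unsupported
    }
}

With dst always bound last, this gives 7 or 8 descriptor slots, matching the parameter counts passed to ggml_vk_create_pipeline for rwkv_wkv6_f32 and rwkv_wkv7_f32 above.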
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
new file mode 100644 (file)
index 0000000..deba8c3
--- /dev/null
@@ -0,0 +1,41 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+shared FLOAT_TYPE sum[BLOCK_SIZE];
+
+void main() {
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint tid = gl_LocalInvocationID.x;
+
+    sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
+
+    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+        sum[tid] += xi * xi;
+    }
+
+    // sum up partial sums and write back result
+    barrier();
+    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sum[tid] += sum[tid + s];
+        }
+        barrier();
+    }
+
+    const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
+
+    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+        data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+    }
+}
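The shader above reduces its 512 partial sums with a strided halving loop in shared memory. A CPU-side C++ sketch of the same tree reduction (the GPU version interleaves barrier() between rounds; this serial form is for illustration only):

// After each round, the first `s` slots hold the surviving partials.
static float tree_reduce(float * sum, int block_size) {
    for (int s = block_size / 2; s > 0; s >>= 1) {
        for (int tid = 0; tid < s; ++tid) {
            sum[tid] += sum[tid + s];
        }
    }
    return sum[0];
}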
index ee1fec4e114aace0cc00c069d76724c375157874..eb2ad63ff6bf039f64e5c7b4b30f0bb464b6ecaa 100644 (file)
@@ -434,6 +434,7 @@ void process_shaders() {
     string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
     string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
@@ -528,6 +529,8 @@ void process_shaders() {
 
     string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
     for (auto &c : compiles) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp b/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp
new file mode 100644 (file)
index 0000000..88c1c02
--- /dev/null
@@ -0,0 +1,91 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+#define BLOCK_SIZE 64
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout(push_constant) uniform Parameters {
+    uint B;
+    uint T;
+    uint C;
+    uint H;
+};
+
+layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
+layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; };
+layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; };
+layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; };
+layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; };
+layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; };
+layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; };
+layout(binding = 7) buffer DstBuf { A_TYPE dst[]; };
+
+shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE];
+
+void main() {
+    const uint head_size = BLOCK_SIZE;
+    const uint batch_id = gl_WorkGroupID.x / H;
+    const uint head_id = gl_WorkGroupID.x % H;
+    const uint tid = gl_LocalInvocationID.x;
+
+    const uint state_size = C * head_size;
+    const uint n_seq_tokens = T / B;
+
+    if (batch_id >= B || head_id >= H) {
+        return;
+    }
+
+    A_TYPE state[BLOCK_SIZE];
+    [[unroll]] for (uint i = 0; i < head_size; i++) {
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+                          + tid * head_size + i];
+    }
+
+    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+    for (uint t = start_t; t < end_t; t += C) {
+        barrier();
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+        barrier();
+
+        A_TYPE sa = 0.0;
+        [[unroll]] for (uint j = 0; j < head_size; j += 4) {
+            vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
+            vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
+            sa += dot(s_vec, a_vec);
+        }
+
+        const A_TYPE v_val = v[t];
+        A_TYPE y = 0.0;
+
+        [[unroll]] for (uint j = 0; j < head_size; j += 4) {
+            vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
+            vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
+            vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            vec4 kv = k_vec * v_val;
+            s_vec = s_vec * w_vec + kv + sa * b_vec;
+            y += dot(r_vec, s_vec);
+
+            state[j] = s_vec.x;
+            state[j+1] = s_vec.y;
+            state[j+2] = s_vec.z;
+            state[j+3] = s_vec.w;
+        }
+
+        dst[t] = y;
+    }
+
+    [[unroll]] for (uint i = 0; i < head_size; i++) {
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+            + tid * head_size + i] = state[i];
+    }
+}
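
For orientation: each invocation of the shader above owns one row of the per-head state matrix $S \in \mathbb{R}^{N \times N}$ ($N$ = head_size = BLOCK_SIZE), and the `sa` accumulator is that row's entry of $S_{t-1} a_t$. In matrix form the token loop evaluates the RWKV-7 recurrence

    $$ S_t = S_{t-1}\,\bigl(\mathrm{diag}(w_t) + a_t b_t^{\top}\bigr) + v_t k_t^{\top}, \qquad y_t = S_t\, r_t $$

and the trailing loop writes the final state to `dst` at offset `T * C`, matching the concatenated output-plus-state layout that `ggml_rwkv_wkv7` allocates in the ggml.c hunk below.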
index 89409bb0e42a54a05dd97c606be4156f576aa108..2e081d5910c6ed4a23d4c69daa93224bd488ef22 100644 (file)
@@ -929,6 +929,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "RMS_NORM",
     "RMS_NORM_BACK",
     "GROUP_NORM",
+    "L2_NORM",
 
     "MUL_MAT",
     "MUL_MAT_ID",
@@ -977,6 +978,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ADD_REL_POS",
     "RWKV_WKV6",
     "GATED_LINEAR_ATTN",
+    "RWKV_WKV7",
 
     "UNARY",
 
@@ -996,7 +998,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1026,6 +1028,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rms_norm(x)",
     "rms_norm_back(x)",
     "group_norm(x)",
+    "l2_norm(x)",
 
     "X*Y",
     "X[i]*Y",
@@ -1074,6 +1077,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "add_rel_pos(x)",
     "rwkv_wkv6(k, v, r, tf, td, s)",
     "gated_linear_attn(k, v, q, gate, s)",
+    "rwkv_wkv7(r, w, k, v, a, b, s)",
 
     "unary(x)",
 
@@ -1093,7 +1097,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -2686,6 +2690,37 @@ struct ggml_tensor * ggml_group_norm_inplace(
     return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
 }
 
+// ggml_l2_norm
+
+static struct ggml_tensor * ggml_l2_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 eps,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params_f32(result, 0, eps);
+
+    result->op     = GGML_OP_L2_NORM;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_l2_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 eps) {
+    return ggml_l2_norm_impl(ctx, a, eps, false);
+}
+
+struct ggml_tensor * ggml_l2_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 eps) {
+    return ggml_l2_norm_impl(ctx, a, eps, true);
+}
+
 // ggml_mul_mat
 
 static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
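
For reference: unlike `ggml_rms_norm`, which only rescales by the root mean square, `ggml_l2_norm` normalizes each row to unit Euclidean length. As a sketch (the exact epsilon handling is left to the per-backend kernels, which are not part of this hunk):

    $$ \mathrm{l2\_norm}(x) = \frac{x}{\max(\lVert x \rVert_2,\ \varepsilon)} $$

The op is introduced here because RWKV-7 unit-normalizes certain per-head vectors in its time-mix path.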
@@ -4720,6 +4755,54 @@ struct ggml_tensor * ggml_gated_linear_attn(
     return result;
 }
 
+// ggml_rwkv_wkv7
+
+struct ggml_tensor * ggml_rwkv_wkv7(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * r,
+        struct ggml_tensor  * w,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * state) {
+    GGML_ASSERT(ggml_is_contiguous(r));
+    GGML_ASSERT(ggml_is_contiguous(w));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_is_contiguous(b));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
+        GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
+        GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_RWKV_WKV7;
+    result->src[0] = r;
+    result->src[1] = w;
+    result->src[2] = k;
+    result->src[3] = v;
+    result->src[4] = a;
+    result->src[5] = b;
+    result->src[6] = state;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
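
A minimal caller-side sketch of what the shape asserts above imply; the sizes and variable names are illustrative, not taken from the patch:

    #include "ggml.h"

    // illustrative sizes: head size S, H heads, n_tokens tokens, one sequence
    const int64_t S = 64, H = 8, n_tokens = 16, n_seqs = 1;

    struct ggml_init_params ip = { /*.mem_size =*/ 256u*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    // r, w, k, v, a, b: contiguous [S, H, n_tokens] tensors (per the asserts; r's shape is implied)
    struct ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * w = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    // the recurrent state carries S*S*H elements per sequence
    struct ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S*S*H, n_seqs);

    struct ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, state);

    // out is [S*H, n_tokens + S*n_seqs]: the first n_tokens rows are the
    // attention output y_t, the remaining S*n_seqs rows the updated state
    struct ggml_tensor * y         = ggml_view_2d(ctx, out, S*H, n_tokens, out->nb[1], 0);
    struct ggml_tensor * new_state = ggml_view_2d(ctx, out, S*H, S*n_seqs, out->nb[1], n_tokens*out->nb[1]);

Splitting `out` this way mirrors the `// concat output and new_state` comment: the graph consumer views off the y rows for the layer output and stores the tail rows back into the per-sequence state cache.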
index 19624eae04ece3db0c2d8d3438025d1416116ae6..cc48913d9789d5e1ae5d5d7b8441ee74f51c5a99 100644 (file)
@@ -118,22 +118,26 @@ class Keys:
         TOKEN_SHIFT_COUNT                 = "{arch}.token_shift_count"
 
     class Attention:
-        HEAD_COUNT        = "{arch}.attention.head_count"
-        HEAD_COUNT_KV     = "{arch}.attention.head_count_kv"
-        MAX_ALIBI_BIAS    = "{arch}.attention.max_alibi_bias"
-        CLAMP_KQV         = "{arch}.attention.clamp_kqv"
-        KEY_LENGTH        = "{arch}.attention.key_length"
-        VALUE_LENGTH      = "{arch}.attention.value_length"
-        LAYERNORM_EPS     = "{arch}.attention.layer_norm_epsilon"
-        LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
-        GROUPNORM_EPS     = "{arch}.attention.group_norm_epsilon"
-        GROUPNORM_GROUPS  = "{arch}.attention.group_norm_groups"
-        CAUSAL            = "{arch}.attention.causal"
-        Q_LORA_RANK       = "{arch}.attention.q_lora_rank"
-        KV_LORA_RANK      = "{arch}.attention.kv_lora_rank"
-        REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
-        SLIDING_WINDOW    = "{arch}.attention.sliding_window"
-        SCALE             = "{arch}.attention.scale"
+        HEAD_COUNT                   = "{arch}.attention.head_count"
+        HEAD_COUNT_KV                = "{arch}.attention.head_count_kv"
+        MAX_ALIBI_BIAS               = "{arch}.attention.max_alibi_bias"
+        CLAMP_KQV                    = "{arch}.attention.clamp_kqv"
+        KEY_LENGTH                   = "{arch}.attention.key_length"
+        VALUE_LENGTH                 = "{arch}.attention.value_length"
+        LAYERNORM_EPS                = "{arch}.attention.layer_norm_epsilon"
+        LAYERNORM_RMS_EPS            = "{arch}.attention.layer_norm_rms_epsilon"
+        GROUPNORM_EPS                = "{arch}.attention.group_norm_epsilon"
+        GROUPNORM_GROUPS             = "{arch}.attention.group_norm_groups"
+        CAUSAL                       = "{arch}.attention.causal"
+        Q_LORA_RANK                  = "{arch}.attention.q_lora_rank"
+        KV_LORA_RANK                 = "{arch}.attention.kv_lora_rank"
+        DECAY_LORA_RANK              = "{arch}.attention.decay_lora_rank"
+        ICLR_LORA_RANK               = "{arch}.attention.iclr_lora_rank"
+        VALUE_RESIDUAL_MIX_LORA_RANK = "{arch}.attention.value_residual_mix_lora_rank"
+        GATE_LORA_RANK               = "{arch}.attention.gate_lora_rank"
+        REL_BUCKETS_COUNT            = "{arch}.attention.relative_buckets_count"
+        SLIDING_WINDOW               = "{arch}.attention.sliding_window"
+        SCALE                        = "{arch}.attention.scale"
 
     class Rope:
         DIMENSION_COUNT         = "{arch}.rope.dimension_count"
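
The four new `*_lora_rank` keys size the low-rank, two-matrix bottlenecks that RWKV-7 uses in place of full projections for its decay, in-context learning rate (iclr), value-residual mix, and gate paths; each rank r is the inner dimension of an n_embd x r / r x n_embd pair (the `*1`/`*2` tensors mapped later in this commit, with `*0` as the bias term). As an example, the decay in the reference RWKV-7 formulation (a sketch, not code from this patch) is

    $$ w_t = \exp\!\bigl(-e^{-1/2}\,\sigma\bigl(w_0 + \tanh(x_t W_1)\, W_2\bigr)\bigr), \qquad W_1 \in \mathbb{R}^{n_\mathrm{embd} \times r},\ W_2 \in \mathbb{R}^{r \times n_\mathrm{embd}} $$

with r = `decay_lora_rank`.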
@@ -257,6 +261,8 @@ class MODEL_ARCH(IntEnum):
     STARCODER2       = auto()
     RWKV6            = auto()
     RWKV6QWEN2       = auto()
+    RWKV7            = auto()
+    ARWKV7           = auto()
     MAMBA            = auto()
     XVERSE           = auto()
     COMMAND_R        = auto()
@@ -329,8 +335,20 @@ class MODEL_TENSOR(IntEnum):
     SSM_A                = auto()
     SSM_D                = auto()
     SSM_OUT              = auto()
+    TIME_MIX_W0          = auto()
     TIME_MIX_W1          = auto()
     TIME_MIX_W2          = auto()
+    TIME_MIX_A0          = auto()
+    TIME_MIX_A1          = auto()
+    TIME_MIX_A2          = auto()
+    TIME_MIX_V0          = auto()
+    TIME_MIX_V1          = auto()
+    TIME_MIX_V2          = auto()
+    TIME_MIX_G1          = auto()
+    TIME_MIX_G2          = auto()
+    TIME_MIX_K_K         = auto()
+    TIME_MIX_K_A         = auto()
+    TIME_MIX_R_K         = auto()
     TIME_MIX_LERP_X      = auto()
     TIME_MIX_LERP_K      = auto()
     TIME_MIX_LERP_V      = auto()
@@ -445,6 +463,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.STARCODER2:       "starcoder2",
     MODEL_ARCH.RWKV6:            "rwkv6",
     MODEL_ARCH.RWKV6QWEN2:       "rwkv6qwen2",
+    MODEL_ARCH.RWKV7:            "rwkv7",
+    MODEL_ARCH.ARWKV7:           "arwkv7",
     MODEL_ARCH.MAMBA:            "mamba",
     MODEL_ARCH.XVERSE:           "xverse",
     MODEL_ARCH.COMMAND_R:        "command-r",
@@ -517,8 +537,20 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.SSM_A:                     "blk.{bid}.ssm_a",
     MODEL_TENSOR.SSM_D:                     "blk.{bid}.ssm_d",
     MODEL_TENSOR.SSM_OUT:                   "blk.{bid}.ssm_out",
+    MODEL_TENSOR.TIME_MIX_W0:               "blk.{bid}.time_mix_w0",
     MODEL_TENSOR.TIME_MIX_W1:               "blk.{bid}.time_mix_w1",
     MODEL_TENSOR.TIME_MIX_W2:               "blk.{bid}.time_mix_w2",
+    MODEL_TENSOR.TIME_MIX_A0:               "blk.{bid}.time_mix_a0",
+    MODEL_TENSOR.TIME_MIX_A1:               "blk.{bid}.time_mix_a1",
+    MODEL_TENSOR.TIME_MIX_A2:               "blk.{bid}.time_mix_a2",
+    MODEL_TENSOR.TIME_MIX_V0:               "blk.{bid}.time_mix_v0",
+    MODEL_TENSOR.TIME_MIX_V1:               "blk.{bid}.time_mix_v1",
+    MODEL_TENSOR.TIME_MIX_V2:               "blk.{bid}.time_mix_v2",
+    MODEL_TENSOR.TIME_MIX_G1:               "blk.{bid}.time_mix_g1",
+    MODEL_TENSOR.TIME_MIX_G2:               "blk.{bid}.time_mix_g2",
+    MODEL_TENSOR.TIME_MIX_K_K:              "blk.{bid}.time_mix_k_k",
+    MODEL_TENSOR.TIME_MIX_K_A:              "blk.{bid}.time_mix_k_a",
+    MODEL_TENSOR.TIME_MIX_R_K:              "blk.{bid}.time_mix_r_k",
     MODEL_TENSOR.TIME_MIX_LERP_X:           "blk.{bid}.time_mix_lerp_x",
     MODEL_TENSOR.TIME_MIX_LERP_K:           "blk.{bid}.time_mix_lerp_k",
     MODEL_TENSOR.TIME_MIX_LERP_V:           "blk.{bid}.time_mix_lerp_v",
@@ -1172,6 +1204,68 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.RWKV7: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_NORM_2,
+        MODEL_TENSOR.TIME_MIX_LERP_FUSED,
+        MODEL_TENSOR.TIME_MIX_W0,
+        MODEL_TENSOR.TIME_MIX_W1,
+        MODEL_TENSOR.TIME_MIX_W2,
+        MODEL_TENSOR.TIME_MIX_A0,
+        MODEL_TENSOR.TIME_MIX_A1,
+        MODEL_TENSOR.TIME_MIX_A2,
+        MODEL_TENSOR.TIME_MIX_V0,
+        MODEL_TENSOR.TIME_MIX_V1,
+        MODEL_TENSOR.TIME_MIX_V2,
+        MODEL_TENSOR.TIME_MIX_G1,
+        MODEL_TENSOR.TIME_MIX_G2,
+        MODEL_TENSOR.TIME_MIX_K_K,
+        MODEL_TENSOR.TIME_MIX_K_A,
+        MODEL_TENSOR.TIME_MIX_R_K,
+        MODEL_TENSOR.TIME_MIX_KEY,
+        MODEL_TENSOR.TIME_MIX_VALUE,
+        MODEL_TENSOR.TIME_MIX_RECEPTANCE,
+        MODEL_TENSOR.TIME_MIX_LN,
+        MODEL_TENSOR.TIME_MIX_OUTPUT,
+        MODEL_TENSOR.CHANNEL_MIX_LERP_K,
+        MODEL_TENSOR.CHANNEL_MIX_KEY,
+        MODEL_TENSOR.CHANNEL_MIX_VALUE,
+    ],
+    MODEL_ARCH.ARWKV7: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.TIME_MIX_LERP_FUSED,
+        MODEL_TENSOR.TIME_MIX_W0,
+        MODEL_TENSOR.TIME_MIX_W1,
+        MODEL_TENSOR.TIME_MIX_W2,
+        MODEL_TENSOR.TIME_MIX_A0,
+        MODEL_TENSOR.TIME_MIX_A1,
+        MODEL_TENSOR.TIME_MIX_A2,
+        MODEL_TENSOR.TIME_MIX_V0,
+        MODEL_TENSOR.TIME_MIX_V1,
+        MODEL_TENSOR.TIME_MIX_V2,
+        MODEL_TENSOR.TIME_MIX_G1,
+        MODEL_TENSOR.TIME_MIX_G2,
+        MODEL_TENSOR.TIME_MIX_K_K,
+        MODEL_TENSOR.TIME_MIX_K_A,
+        MODEL_TENSOR.TIME_MIX_R_K,
+        MODEL_TENSOR.TIME_MIX_KEY,
+        MODEL_TENSOR.TIME_MIX_VALUE,
+        MODEL_TENSOR.TIME_MIX_RECEPTANCE,
+        MODEL_TENSOR.TIME_MIX_LN,
+        MODEL_TENSOR.TIME_MIX_OUTPUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.MAMBA: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index 080d2b9dce5cb36dc0ee64ce19f53e06690be8ad..af8b388dfaba5b13f798a67708320ab5bf84b6a4 100644 (file)
@@ -767,6 +767,18 @@ class GGUFWriter:
     def add_kv_lora_rank(self, length: int) -> None:
         self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
 
+    def add_decay_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.DECAY_LORA_RANK.format(arch=self.arch), length)
+
+    def add_iclr_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.ICLR_LORA_RANK.format(arch=self.arch), length)
+
+    def add_value_residual_mix_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length)
+
+    def add_gate_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length)
+
     def add_relative_attn_buckets_count(self, value: int) -> None:
         self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
 
index 617791e240b6053055209db9790ad365ad34f6d0..8d4a2b032018395dab06356b5b92117fa5f023e4 100644 (file)
@@ -27,7 +27,8 @@ class TensorNameMap:
             "embedding.word_embeddings",                 # chatglm
             "transformer.token_embeddings",              # openelm
             "shared",                                    # t5
-            "rwkv.embeddings",                           # rwkv
+            "rwkv.embeddings",                           # rwkv6
+            "model.embeddings",                          # rwkv7
         ),
 
         # Token type embeddings
@@ -42,6 +43,9 @@ class TensorNameMap:
             "emb_ln",                     # nomic-bert
             "transformer.norm",           # openelm
             "rwkv.blocks.0.pre_ln",       # rwkv
+            "rwkv.blocks.0.pre_ln",       # rwkv6
+            "model.pre_ln",               # rwkv7
+            "model.layers.0.pre_norm",    # rwkv7
             "backbone.norm",              # wavtokenizer
         ),
 
@@ -81,7 +85,8 @@ class TensorNameMap:
             "encoder.final_layernorm",                 # chatglm
             "transformer.norm",                        # openelm
             "model.norm",                              # nemotron
-            "rwkv.ln_out",                             # rwkv
+            "rwkv.ln_out",                             # rwkv6
+            "model.ln_out",                            # rwkv7
             "backbone.final_layer_norm",               # wavtokenizer
         ),
 
@@ -122,14 +127,16 @@ class TensorNameMap:
             "transformer.blocks.{bid}.norm_attn_norm.norm_1",       # dbrx
             "encoder.layers.{bid}.input_layernorm",                 # chatglm
             "transformer.layers.{bid}.attn_norm",                   # openelm
-            "rwkv.blocks.{bid}.ln1",                                # rwkv
+            "rwkv.blocks.{bid}.ln1",                                # rwkv6
+            "model.layers.{bid}.ln1",                               # rwkv7
         ),
 
         # Attention norm 2
         MODEL_TENSOR.ATTN_NORM_2: (
             "transformer.h.{bid}.ln_attn",                  # falcon40b
             "encoder.layer.{bid}.layer_norm_1",             # jina-v2-code
-            "rwkv.blocks.{bid}.ln2",                        # rwkv
+            "rwkv.blocks.{bid}.ln2",                        # rwkv6
+            "model.layers.{bid}.ln2",                       # rwkv7
         ),
 
         # Attention query-key-value
@@ -462,112 +469,174 @@ class TensorNameMap:
             "backbone.layers.{bid}.mixer.out_proj",
         ),
 
+        MODEL_TENSOR.TIME_MIX_W0: (
+            "model.layers.{bid}.attention.w0",            # rwkv7
+        ),
+
         MODEL_TENSOR.TIME_MIX_W1: (
-            "rwkv.blocks.{bid}.attention.time_maa_w1",  # rwkv v6
-            "model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
+            "rwkv.blocks.{bid}.attention.time_maa_w1",    # rwkv6
+            "model.layers.{bid}.self_attn.time_maa_w1",   # rwkv6qwen2
+            "model.layers.{bid}.attention.w1",            # rwkv7
         ),
 
         MODEL_TENSOR.TIME_MIX_W2: (
-            "rwkv.blocks.{bid}.attention.time_maa_w2",  # rwkv v6
-            "model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
+            "rwkv.blocks.{bid}.attention.time_maa_w2",    # rwkv6
+            "model.layers.{bid}.self_attn.time_maa_w2",   # rwkv6qwen2
+            "model.layers.{bid}.attention.w2",            # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_A0: (
+            "model.layers.{bid}.attention.a0",            # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_A1: (
+            "model.layers.{bid}.attention.a1",            # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_A2: (
+            "model.layers.{bid}.attention.a2",            # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_V0: (
+            "model.layers.{bid}.attention.v0",            # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_V1: (
+            "model.layers.{bid}.attention.v1",            # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_V2: (
+            "model.layers.{bid}.attention.v2",            # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_G1: (
+            "model.layers.{bid}.attention.g1",            # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_G2: (
+            "model.layers.{bid}.attention.g2",            # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_K_K: (
+            "model.layers.{bid}.attention.k_k",            # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_K_A: (
+            "model.layers.{bid}.attention.k_a",            # rwkv7
+        ),
+
+        MODEL_TENSOR.TIME_MIX_R_K: (
+            "model.layers.{bid}.attention.r_k",            # rwkv7
         ),
 
         MODEL_TENSOR.TIME_MIX_LERP_X: (
-            "rwkv.blocks.{bid}.attention.time_maa_x",   # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_maa_x",   # rwkv6
             "model.layers.{bid}.self_attn.time_maa_x",  # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_LERP_K: (
-            "rwkv.blocks.{bid}.attention.time_maa_k",   # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_maa_k",   # rwkv6
             "model.layers.{bid}.self_attn.time_maa_k",  # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_LERP_V: (
-            "rwkv.blocks.{bid}.attention.time_maa_v",   # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_maa_v",   # rwkv6
             "model.layers.{bid}.self_attn.time_maa_v",  # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_LERP_R: (
-            "rwkv.blocks.{bid}.attention.time_maa_r",   # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_maa_r",   # rwkv6
             "model.layers.{bid}.self_attn.time_maa_r",  # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_LERP_G: (
-            "rwkv.blocks.{bid}.attention.time_maa_g",   # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_maa_g",   # rwkv6
             "model.layers.{bid}.self_attn.time_maa_g",  # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_LERP_W: (
-            "rwkv.blocks.{bid}.attention.time_maa_w",   # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_maa_w",   # rwkv6
             "model.layers.{bid}.self_attn.time_maa_w",  # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_FIRST: (
-            "rwkv.blocks.{bid}.attention.time_faaaa",   # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_faaaa",   # rwkv6
         ),
 
         MODEL_TENSOR.TIME_MIX_DECAY: (
-            "rwkv.blocks.{bid}.attention.time_decay",   # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_decay",   # rwkv6
             "model.layers.{bid}.self_attn.time_decay",  # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_DECAY_W1: (
-            "rwkv.blocks.{bid}.attention.time_decay_w1",  # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_decay_w1",  # rwkv6
             "model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_DECAY_W2: (
-            "rwkv.blocks.{bid}.attention.time_decay_w2",  # rwkv v6
+            "rwkv.blocks.{bid}.attention.time_decay_w2",  # rwkv6
             "model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_KEY: (
-            "rwkv.blocks.{bid}.attention.key",     # rwkv
+            "rwkv.blocks.{bid}.attention.key",     # rwkv6
             "model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2
+            "model.layers.{bid}.attention.key",    # rwkv7
+            "model.layers.{bid}.attention.k_proj", # rwkv7
         ),
 
         MODEL_TENSOR.TIME_MIX_VALUE: (
-            "rwkv.blocks.{bid}.attention.value",   # rwkv
+            "rwkv.blocks.{bid}.attention.value",   # rwkv6
             "model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2
+            "model.layers.{bid}.attention.value",  # rwkv7
+            "model.layers.{bid}.attention.v_proj", # rwkv7
         ),
 
         MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
-            "rwkv.blocks.{bid}.attention.receptance", # rwkv
-            "model.layers.{bid}.self_attn.q_proj",    # rwkv6qwen2
+            "rwkv.blocks.{bid}.attention.receptance",  # rwkv6
+            "model.layers.{bid}.self_attn.q_proj",     # rwkv6qwen2
+            "model.layers.{bid}.attention.receptance", # rwkv7
+            "model.layers.{bid}.attention.r_proj",     # rwkv7
         ),
 
         MODEL_TENSOR.TIME_MIX_GATE: (
-            "rwkv.blocks.{bid}.attention.gate",  # rwkv
-            "model.layers.{bid}.self_attn.gate", # rwkv6qwen2
+            "rwkv.blocks.{bid}.attention.gate",        # rwkv6
+            "model.layers.{bid}.self_attn.gate",       # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_LN: (
-            "rwkv.blocks.{bid}.attention.ln_x", # rwkv
+            "rwkv.blocks.{bid}.attention.ln_x", # rwkv6
+            "model.layers.{bid}.attention.ln_x" # rwkv7
         ),
 
         MODEL_TENSOR.TIME_MIX_OUTPUT: (
-            "rwkv.blocks.{bid}.attention.output",  # rwkv
+            "rwkv.blocks.{bid}.attention.output",  # rwkv6
             "model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2
+            "model.layers.{bid}.attention.output", # rwkv7
+            "model.layers.{bid}.attention.o_proj", # rwkv7
         ),
 
         MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
-            "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv v6
+            "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv6
+            "model.layers.{bid}.feed_forward.x_k",       # rwkv7
         ),
 
         MODEL_TENSOR.CHANNEL_MIX_LERP_R: (
-            "rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv v6
+            "rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv6
         ),
 
         MODEL_TENSOR.CHANNEL_MIX_KEY: (
-            "rwkv.blocks.{bid}.feed_forward.key", # rwkv
+            "rwkv.blocks.{bid}.feed_forward.key",  # rwkv6
+            "model.layers.{bid}.feed_forward.key", # rwkv7
         ),
 
         MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: (
-            "rwkv.blocks.{bid}.feed_forward.receptance", # rwkv
+            "rwkv.blocks.{bid}.feed_forward.receptance", # rwkv6
         ),
 
         MODEL_TENSOR.CHANNEL_MIX_VALUE: (
-            "rwkv.blocks.{bid}.feed_forward.value", # rwkv
+            "rwkv.blocks.{bid}.feed_forward.value",  # rwkv6
+            "model.layers.{bid}.feed_forward.value", # rwkv7
         ),
 
         MODEL_TENSOR.ATTN_Q_A: (
index 28f2bbc8f72bffc39562b260a5d5b9458715b31f..9debb56cc80d543c1b13e2e05fba11f1e7d6c233 100644 (file)
@@ -59,6 +59,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_EXAONE,           "exaone"           },
     { LLM_ARCH_RWKV6,            "rwkv6"            },
     { LLM_ARCH_RWKV6QWEN2,       "rwkv6qwen2"       },
+    { LLM_ARCH_RWKV7,            "rwkv7"            },
+    { LLM_ARCH_ARWKV7,           "arwkv7"           },
     { LLM_ARCH_GRANITE,          "granite"          },
     { LLM_ARCH_GRANITE_MOE,      "granitemoe"       },
     { LLM_ARCH_CHAMELEON,        "chameleon"        },
@@ -110,22 +112,26 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EMBEDDING_SCALE,                   "%s.embedding_scale"                   },
     { LLM_KV_TOKEN_SHIFT_COUNT,                 "%s.token_shift_count"                 },
 
-    { LLM_KV_ATTENTION_HEAD_COUNT,             "%s.attention.head_count"             },
-    { LLM_KV_ATTENTION_HEAD_COUNT_KV,          "%s.attention.head_count_kv"          },
-    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS,         "%s.attention.max_alibi_bias"         },
-    { LLM_KV_ATTENTION_CLAMP_KQV,              "%s.attention.clamp_kqv"              },
-    { LLM_KV_ATTENTION_KEY_LENGTH,             "%s.attention.key_length"             },
-    { LLM_KV_ATTENTION_VALUE_LENGTH,           "%s.attention.value_length"           },
-    { LLM_KV_ATTENTION_LAYERNORM_EPS,          "%s.attention.layer_norm_epsilon"     },
-    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,      "%s.attention.layer_norm_rms_epsilon" },
-    { LLM_KV_ATTENTION_GROUPNORM_EPS,          "%s.attention.group_norm_epsilon"     },
-    { LLM_KV_ATTENTION_GROUPNORM_GROUPS,       "%s.attention.group_norm_groups"      },
-    { LLM_KV_ATTENTION_CAUSAL,                 "%s.attention.causal"                 },
-    { LLM_KV_ATTENTION_Q_LORA_RANK,            "%s.attention.q_lora_rank"            },
-    { LLM_KV_ATTENTION_KV_LORA_RANK,           "%s.attention.kv_lora_rank"           },
-    { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
-    { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
-    { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                  },
+    { LLM_KV_ATTENTION_HEAD_COUNT,                   "%s.attention.head_count"                   },
+    { LLM_KV_ATTENTION_HEAD_COUNT_KV,                "%s.attention.head_count_kv"                },
+    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS,               "%s.attention.max_alibi_bias"               },
+    { LLM_KV_ATTENTION_CLAMP_KQV,                    "%s.attention.clamp_kqv"                    },
+    { LLM_KV_ATTENTION_KEY_LENGTH,                   "%s.attention.key_length"                   },
+    { LLM_KV_ATTENTION_VALUE_LENGTH,                 "%s.attention.value_length"                 },
+    { LLM_KV_ATTENTION_LAYERNORM_EPS,                "%s.attention.layer_norm_epsilon"           },
+    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,            "%s.attention.layer_norm_rms_epsilon"       },
+    { LLM_KV_ATTENTION_GROUPNORM_EPS,                "%s.attention.group_norm_epsilon"           },
+    { LLM_KV_ATTENTION_GROUPNORM_GROUPS,             "%s.attention.group_norm_groups"            },
+    { LLM_KV_ATTENTION_CAUSAL,                       "%s.attention.causal"                       },
+    { LLM_KV_ATTENTION_Q_LORA_RANK,                  "%s.attention.q_lora_rank"                  },
+    { LLM_KV_ATTENTION_KV_LORA_RANK,                 "%s.attention.kv_lora_rank"                 },
+    { LLM_KV_ATTENTION_DECAY_LORA_RANK,              "%s.attention.decay_lora_rank"              },
+    { LLM_KV_ATTENTION_ICLR_LORA_RANK,               "%s.attention.iclr_lora_rank"               },
+    { LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" },
+    { LLM_KV_ATTENTION_GATE_LORA_RANK,               "%s.attention.gate_lora_rank"               },
+    { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,       "%s.attention.relative_buckets_count"       },
+    { LLM_KV_ATTENTION_SLIDING_WINDOW,               "%s.attention.sliding_window"               },
+    { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
 
     { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
     { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"              },
@@ -1238,6 +1244,74 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,                    "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_RWKV7,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,                "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM,           "token_embd_norm" },
+            { LLM_TENSOR_OUTPUT_NORM,               "output_norm" },
+            { LLM_TENSOR_OUTPUT,                    "output" },
+            { LLM_TENSOR_ATTN_NORM,                 "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_NORM_2,               "blk.%d.attn_norm_2" },
+            { LLM_TENSOR_TIME_MIX_W0,               "blk.%d.time_mix_w0" },
+            { LLM_TENSOR_TIME_MIX_W1,               "blk.%d.time_mix_w1" },
+            { LLM_TENSOR_TIME_MIX_W2,               "blk.%d.time_mix_w2" },
+            { LLM_TENSOR_TIME_MIX_A0,               "blk.%d.time_mix_a0" },
+            { LLM_TENSOR_TIME_MIX_A1,               "blk.%d.time_mix_a1" },
+            { LLM_TENSOR_TIME_MIX_A2,               "blk.%d.time_mix_a2" },
+            { LLM_TENSOR_TIME_MIX_V0,               "blk.%d.time_mix_v0" },
+            { LLM_TENSOR_TIME_MIX_V1,               "blk.%d.time_mix_v1" },
+            { LLM_TENSOR_TIME_MIX_V2,               "blk.%d.time_mix_v2" },
+            { LLM_TENSOR_TIME_MIX_G1,               "blk.%d.time_mix_g1" },
+            { LLM_TENSOR_TIME_MIX_G2,               "blk.%d.time_mix_g2" },
+            { LLM_TENSOR_TIME_MIX_K_K,              "blk.%d.time_mix_k_k" },
+            { LLM_TENSOR_TIME_MIX_K_A,              "blk.%d.time_mix_k_a" },
+            { LLM_TENSOR_TIME_MIX_R_K,              "blk.%d.time_mix_r_k" },
+            { LLM_TENSOR_TIME_MIX_LERP_FUSED,       "blk.%d.time_mix_lerp_fused" },
+            { LLM_TENSOR_TIME_MIX_KEY,              "blk.%d.time_mix_key" },
+            { LLM_TENSOR_TIME_MIX_VALUE,            "blk.%d.time_mix_value" },
+            { LLM_TENSOR_TIME_MIX_RECEPTANCE,       "blk.%d.time_mix_receptance" },
+            { LLM_TENSOR_TIME_MIX_LN,               "blk.%d.time_mix_ln" },
+            { LLM_TENSOR_TIME_MIX_OUTPUT,           "blk.%d.time_mix_output" },
+            { LLM_TENSOR_CHANNEL_MIX_LERP_K,        "blk.%d.channel_mix_lerp_k" },
+            { LLM_TENSOR_CHANNEL_MIX_KEY,           "blk.%d.channel_mix_key" },
+            { LLM_TENSOR_CHANNEL_MIX_VALUE,         "blk.%d.channel_mix_value" },
+        },
+    },
+    {
+        LLM_ARCH_ARWKV7,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,                "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM,           "token_embd_norm" },
+            { LLM_TENSOR_OUTPUT_NORM,               "output_norm" },
+            { LLM_TENSOR_OUTPUT,                    "output" },
+            { LLM_TENSOR_ATTN_NORM,                 "blk.%d.attn_norm" },
+            { LLM_TENSOR_TIME_MIX_W0,               "blk.%d.time_mix_w0" },
+            { LLM_TENSOR_TIME_MIX_W1,               "blk.%d.time_mix_w1" },
+            { LLM_TENSOR_TIME_MIX_W2,               "blk.%d.time_mix_w2" },
+            { LLM_TENSOR_TIME_MIX_A0,               "blk.%d.time_mix_a0" },
+            { LLM_TENSOR_TIME_MIX_A1,               "blk.%d.time_mix_a1" },
+            { LLM_TENSOR_TIME_MIX_A2,               "blk.%d.time_mix_a2" },
+            { LLM_TENSOR_TIME_MIX_V0,               "blk.%d.time_mix_v0" },
+            { LLM_TENSOR_TIME_MIX_V1,               "blk.%d.time_mix_v1" },
+            { LLM_TENSOR_TIME_MIX_V2,               "blk.%d.time_mix_v2" },
+            { LLM_TENSOR_TIME_MIX_G1,               "blk.%d.time_mix_g1" },
+            { LLM_TENSOR_TIME_MIX_G2,               "blk.%d.time_mix_g2" },
+            { LLM_TENSOR_TIME_MIX_K_K,              "blk.%d.time_mix_k_k" },
+            { LLM_TENSOR_TIME_MIX_K_A,              "blk.%d.time_mix_k_a" },
+            { LLM_TENSOR_TIME_MIX_R_K,              "blk.%d.time_mix_r_k" },
+            { LLM_TENSOR_TIME_MIX_LERP_FUSED,       "blk.%d.time_mix_lerp_fused" },
+            { LLM_TENSOR_TIME_MIX_KEY,              "blk.%d.time_mix_key" },
+            { LLM_TENSOR_TIME_MIX_VALUE,            "blk.%d.time_mix_value" },
+            { LLM_TENSOR_TIME_MIX_RECEPTANCE,       "blk.%d.time_mix_receptance" },
+            { LLM_TENSOR_TIME_MIX_LN,               "blk.%d.time_mix_ln" },
+            { LLM_TENSOR_TIME_MIX_OUTPUT,           "blk.%d.time_mix_output" },
+            { LLM_TENSOR_FFN_NORM,                  "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,                  "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,                  "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,                    "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_GRANITE,
         {
@@ -1397,6 +1471,12 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_SSM_OUT,                    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_W1,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_W2,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_A1,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_A2,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_V1,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_V2,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_G1,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_G2,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_DECAY_W1,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_DECAY_W2,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_KEY,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -1415,6 +1495,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_TIME_MIX_LN,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_CHANNEL_MIX_LERP_K,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_CHANNEL_MIX_LERP_R,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_TIME_MIX_K_K,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_TIME_MIX_K_A,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_TIME_MIX_R_K,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_TIME_MIX_LERP_W,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_K,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_V,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
@@ -1422,6 +1505,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_TIME_MIX_LERP_G,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_FUSED,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_DECAY,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_W0,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_A0,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_V0,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_FIRST,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
     {LLM_TENSOR_ATTN_NORM,                  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_ATTN_NORM_2,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
index 2ec2e2362eba1a3b0892e905235dfefb63bdbab0..a28815d8a14c7bdadf4f79e471bc85cd8568ac80 100644 (file)
@@ -63,6 +63,8 @@ enum llm_arch {
     LLM_ARCH_EXAONE,
     LLM_ARCH_RWKV6,
     LLM_ARCH_RWKV6QWEN2,
+    LLM_ARCH_RWKV7,
+    LLM_ARCH_ARWKV7,
     LLM_ARCH_GRANITE,
     LLM_ARCH_GRANITE_MOE,
     LLM_ARCH_CHAMELEON,
@@ -127,6 +129,10 @@ enum llm_kv {
     LLM_KV_ATTENTION_CAUSAL,
     LLM_KV_ATTENTION_Q_LORA_RANK,
     LLM_KV_ATTENTION_KV_LORA_RANK,
+    LLM_KV_ATTENTION_DECAY_LORA_RANK,
+    LLM_KV_ATTENTION_ICLR_LORA_RANK,
+    LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK,
+    LLM_KV_ATTENTION_GATE_LORA_RANK,
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
@@ -250,8 +256,20 @@ enum llm_tensor {
     LLM_TENSOR_SSM_A,
     LLM_TENSOR_SSM_D,
     LLM_TENSOR_SSM_OUT,
+    LLM_TENSOR_TIME_MIX_W0,
     LLM_TENSOR_TIME_MIX_W1,
     LLM_TENSOR_TIME_MIX_W2,
+    LLM_TENSOR_TIME_MIX_A0,
+    LLM_TENSOR_TIME_MIX_A1,
+    LLM_TENSOR_TIME_MIX_A2,
+    LLM_TENSOR_TIME_MIX_V0,
+    LLM_TENSOR_TIME_MIX_V1,
+    LLM_TENSOR_TIME_MIX_V2,
+    LLM_TENSOR_TIME_MIX_G1,
+    LLM_TENSOR_TIME_MIX_G2,
+    LLM_TENSOR_TIME_MIX_K_K,
+    LLM_TENSOR_TIME_MIX_K_A,
+    LLM_TENSOR_TIME_MIX_R_K,
     LLM_TENSOR_TIME_MIX_LERP_X,
     LLM_TENSOR_TIME_MIX_LERP_W,
     LLM_TENSOR_TIME_MIX_LERP_K,
index dbb7abd317b6f7b86f5511074e942c3c2c04c18a..bb17ba86dc2fb40371df8be010f58e303ac24060 100644 (file)
@@ -76,6 +76,10 @@ struct llama_hparams {
     uint32_t time_decay_extra_dim   = 0;
     uint32_t wkv_head_size          = 0;
     uint32_t token_shift_count      = 2;
+    uint32_t n_lora_decay           = 0;
+    uint32_t n_lora_iclr            = 0;
+    uint32_t n_lora_value_res_mix   = 0;
+    uint32_t n_lora_gate            = 0;
 
     float    rope_attn_factor = 1.0f;
     float    rope_freq_base_train;
index 4b288d8f66a33c716c1c111edf6dd565797d61b5..c571aa69b671c480e39facc238b4bd13fe332989 100644 (file)
@@ -32,6 +32,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_109M:          return "109M";
         case LLM_TYPE_137M:          return "137M";
         case LLM_TYPE_160M:          return "160M";
+        case LLM_TYPE_190M:          return "190M";
         case LLM_TYPE_220M:          return "220M";
         case LLM_TYPE_250M:          return "250M";
         case LLM_TYPE_270M:          return "270M";
@@ -48,6 +49,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_1_6B:          return "1.6B";
         case LLM_TYPE_2B:            return "2B";
         case LLM_TYPE_2_8B:          return "2.8B";
+        case LLM_TYPE_2_9B:          return "2.9B";
         case LLM_TYPE_3B:            return "3B";
         case LLM_TYPE_4B:            return "4B";
         case LLM_TYPE_6B:            return "6B";
@@ -1250,6 +1252,36 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_RWKV7:
+        case LLM_ARCH_ARWKV7:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,                hparams.f_norm_eps, false);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,            hparams.f_norm_rms_eps, false);
+                ml.get_key(LLM_KV_WKV_HEAD_SIZE,                          hparams.wkv_head_size);
+                ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK,              hparams.n_lora_decay);
+                ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK,               hparams.n_lora_iclr);
+                ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix);
+                ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK,               hparams.n_lora_gate, false);
+                ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT,                      hparams.token_shift_count, false);
+
+                switch (hparams.n_layer) {
+                    case 12: type = LLM_TYPE_190M; break;
+                    case 24:
+                        switch (hparams.n_embd) {
+                            case 1024: type = LLM_TYPE_450M; break;
+                            case 2048: type = LLM_TYPE_1_5B; break;
+                            default: type = LLM_TYPE_UNKNOWN;
+                        } break;
+                    case 28:
+                        switch (hparams.n_embd) {
+                            case 1536: type = LLM_TYPE_1_5B; break;
+                            case 3584: type = LLM_TYPE_7B; break;
+                            default: type = LLM_TYPE_UNKNOWN;
+                        } break;
+                    case 32: type = LLM_TYPE_2_9B; break; // RWKV-7-World
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
             {
@@ -3366,6 +3398,146 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
+            case LLM_ARCH_RWKV7:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // Block 0, LN0
+                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    const int n_lora_decay = hparams.n_lora_decay;
+                    const int n_lora_iclr = hparams.n_lora_iclr;
+                    const int n_lora_value_res_mix = hparams.n_lora_value_res_mix;
+                    const int n_lora_gate = hparams.n_lora_gate;
+                    const int attn_hidden_size = n_embd;
+                    const int ffn_size = hparams.n_ff_arr[0];
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, 0);
+
+                        layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0);
+                        layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0);
+                        layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0);
+
+                        layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0);
+                        layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0);
+                        layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0);
+
+                        if (i == 0) {
+                            // actually not used
+                            layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
+                            layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0);
+                            layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0);
+                        } else {
+                            layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
+                            layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0);
+                            layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0);
+                        }
+
+                        layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, 0);
+                        layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, 0);
+
+                        layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0);
+
+                        layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0);
+                        layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0);
+                        layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0);
+
+                        layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
+
+                        layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0);
+                        layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0);
+                        layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
+
+                        layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
+
+                        layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0);
+                        layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0);
+                    }
+
+                } break;
+            case LLM_ARCH_ARWKV7:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    const int n_lora_decay = hparams.n_lora_decay;
+                    const int n_lora_iclr = hparams.n_lora_iclr;
+                    const int n_lora_value_res_mix = hparams.n_lora_value_res_mix;
+                    const int n_lora_gate = hparams.n_lora_gate;
+                    const int attn_hidden_size = n_embd;
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0);
+                        layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0);
+                        layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0);
+
+                        layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0);
+                        layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0);
+                        layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0);
+
+                        if (i == 0) {
+                            // actually not used
+                            layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
+                            layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0);
+                            layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0);
+                        } else {
+                            layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
+                            layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0);
+                            layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0);
+                        }
+
+                        layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        try {
+                            layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0);
+                        } catch(std::runtime_error & e) {
+                            // ARWKV models may not have gate tensors
+                            layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0);
+                        }
+
+                        layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0);
+                        layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0);
+                        layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0);
+
+                        layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
+
+                        layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+
+                } break;
             case LLM_ARCH_CHAMELEON:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
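
A note on the two `// actually not used` branches above: layer 0 is the source of the value residual (its own value vector becomes v_first for the rest of the stack), so its v0/v1/v2 tensors are created only so checkpoints load uniformly, and they follow the iclr-rank shapes. In the reference RWKV-7 formulation (a sketch; the time-mix graph code appears further below) every later layer blends its value toward the first layer's:

    $$ v_t \leftarrow v_t + \bigl(v^{(0)}_t - v_t\bigr) \odot \sigma\bigl(v_0 + (x_v V_1)\, V_2\bigr) $$

which is why those layers size V_1/V_2 with `value_residual_mix_lora_rank` instead.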
@@ -10212,6 +10384,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
+        const auto n_seq_tokens = ubatch.n_seq_tokens;
         const auto n_embd = hparams.n_embd;
         const auto head_size = hparams.wkv_head_size;
         const auto n_head = n_embd / head_size;
@@ -10224,6 +10397,10 @@ struct llm_build_rwkv6_base : public llm_graph_context {
         bool is_qrwkv = layer.time_mix_first == nullptr;
 
         ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
+
+        sx  = ggml_reshape_2d(ctx0, sx,  n_embd, n_tokens);
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+
         ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_x), cur);
 
         xxx = ggml_reshape_4d(
@@ -10366,7 +10543,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
         cur = ggml_mul(ctx0, cur, g);
         cur = build_lora_mm(layer.time_mix_output, cur);
 
-        return cur;
+        return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
     }
 };
 
@@ -10389,6 +10566,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
 
         for (int il = 0; il < n_layer; ++il) {
             const llama_layer * layer = &model.layers[il];
+            inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
             ggml_tensor * token_shift = build_rwkv_token_shift_load(
                     gf, state_copy, state_mask, ubatch, il
@@ -10422,9 +10600,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
                     1
                     );
 
-            cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
-            cur = ggml_add(ctx0, cur, ffn_inp);
-
             token_shift = ggml_concat(ctx0,
                     ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)),
                     ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)),
@@ -10432,6 +10607,18 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
                     );
             ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
 
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                ffn_inp  = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp,  n_embd, n_tokens), inp_out_ids);
+                ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids);
+                x_prev   = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev,   n_embd, n_tokens), inp_out_ids);
+                cur      = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur,      n_embd, n_tokens), inp_out_ids);
+            }
+
+            cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
             if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
                 cur = ggml_scale(ctx0, cur, 0.5F);
             }
@@ -10444,12 +10631,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
         }
 
         cur = inpL;
-
-        ggml_tensor * inp_out_ids = build_inp_out_ids();
-
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-
         cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1);
 
         cb(cur, "result_norm", -1);
@@ -10481,10 +10662,9 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
         const auto n_seq_tokens = ubatch.n_seq_tokens;
         const auto n_seqs = ubatch.n_seqs;
 
-        inpL = build_inp_embd(model.tok_embd);
-
         for (int il = 0; il < n_layer; ++il) {
             const llama_layer * layer = &model.layers[il];
+            inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
             ggml_tensor * token_shift = build_rwkv_token_shift_load(
                     gf, state_copy, state_mask, ubatch, il
@@ -10508,6 +10688,13 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
             cb(ffn_inp, "ffn_inp", il);
 
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur     = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
+                ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
+            }
+
             // feed-forward network
             cur = build_norm(ffn_inp,
                     model.layers[il].ffn_norm, NULL,
@@ -10532,10 +10719,358 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
         }
 
         cur = inpL;
-        ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+        cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
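
In both builders above, the il == n_layer - 1 branch now gathers only the rows listed in inp_out_ids before the final channel mix / FFN, so the output head runs just for tokens whose logits are actually requested; previously this gather happened once after the loop (the removed lines), which is what the commit message's "fix shape error with rwkv using llama-parallel" refers to. A standalone sketch of that row gather (plain C++, illustrative only; the real op is ggml_get_rows):

    #include <cassert>
    #include <vector>

    int main() {
        const int n_embd = 4, n_tokens = 6;

        // Activations in ggml's [n_embd, n_tokens] layout: token t owns the
        // contiguous slice x[t*n_embd .. t*n_embd + n_embd).
        std::vector<float> x(n_embd * n_tokens);
        for (int t = 0; t < n_tokens; ++t)
            for (int i = 0; i < n_embd; ++i)
                x[t*n_embd + i] = 100.0f*t + i;

        // Keep only the tokens whose logits are needed, e.g. the last one.
        const std::vector<int> out_ids = {5};
        std::vector<float> y(n_embd * out_ids.size());
        for (size_t r = 0; r < out_ids.size(); ++r)
            for (int i = 0; i < n_embd; ++i)
                y[r*n_embd + i] = x[out_ids[r]*n_embd + i];

        assert(y[0] == 500.0f);  // downstream ops now see 1 row, not 6
        return 0;
    }
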
+
+struct llm_build_rwkv7_base : public llm_graph_context {
+    const llama_model & model;
+
+    llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {
+    }
+
+    ggml_tensor * build_rwkv7_channel_mix(
+            const llama_layer * layer,
+            ggml_tensor * cur,
+            ggml_tensor * x_prev,
+            llm_arch arch) const {
+        ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
+        switch (arch) {
+            case LLM_ARCH_RWKV7:
+                {
+                    ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
+
+                    ggml_tensor * k = ggml_sqr(
+                        ctx0,
+                        ggml_relu(
+                            ctx0,
+                            build_lora_mm(layer->channel_mix_key, xk)
+                        )
+                    );
+
+                    cur = build_lora_mm(layer->channel_mix_value, k);
+                } break;
+            default:
+                GGML_ABORT("fatal error");
+        }
+
+        return cur;
+    }
+
+    ggml_tensor * build_rwkv7_time_mix(
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            ggml_tensor * x_prev,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            ggml_tensor *& first_layer_value,
+            const llama_ubatch & ubatch,
+            int   il) const {
+        const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+        const auto n_tokens = ubatch.n_tokens;
+        const auto n_seqs = ubatch.n_seqs;
+        const auto n_embd = hparams.n_embd;
+        const auto head_size = hparams.wkv_head_size;
+        const auto head_count = n_embd / head_size;
+        const auto n_seq_tokens = ubatch.n_seq_tokens;
+
+        const auto kv_head = kv_self->head;
+
+        const auto & layer = model.layers[il];
+
+        bool has_gating = layer.time_mix_g1 && layer.time_mix_g2;
+
+        ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
+        ggml_tensor * dummy = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5);
+        sx = ggml_repeat(ctx0, sx, dummy);
+
+        ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_fused), cur);
+
+        ggml_tensor * xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        ggml_tensor * xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        ggml_tensor * xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        ggml_tensor * xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        ggml_tensor * xa = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+        ggml_tensor * xg = has_gating ? ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 5 * sizeof(float)) : nullptr;
+
+        ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr);
+        ggml_tensor * w = ggml_add(
+            ctx0,
+            ggml_mul_mat(ctx0, layer.time_mix_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xw))),
+            layer.time_mix_w0
+        );
+        w = ggml_exp(ctx0, ggml_scale(ctx0, ggml_sigmoid(ctx0, w), -0.606531));
+
+        ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk);
+        ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv);
+        if (first_layer_value == nullptr) {
+            first_layer_value = v;
+        } else {
+            // Add the first layer value as a residual connection.
+            v = ggml_add(ctx0, v,
+                ggml_mul(ctx0,
+                    ggml_sub(ctx0, first_layer_value, v),
+                    ggml_sigmoid(ctx0, ggml_add(ctx0,
+                            ggml_mul_mat(ctx0, layer.time_mix_v2, ggml_mul_mat(ctx0, layer.time_mix_v1, xv)),
+                            layer.time_mix_v0
+                        )
+                    )
+                )
+            );
+        }
+
+        ggml_tensor * g = nullptr;
+        if (layer.time_mix_g1 && layer.time_mix_g2) {
+            g = ggml_mul_mat(ctx0, layer.time_mix_g2, ggml_sigmoid(ctx0, ggml_mul_mat(ctx0, layer.time_mix_g1, xg)));
+        }
+
+        ggml_tensor * a = ggml_sigmoid(ctx0,
+            ggml_add(
+                ctx0,
+                ggml_mul_mat(ctx0, layer.time_mix_a2, ggml_mul_mat(ctx0, layer.time_mix_a1, xa)),
+                layer.time_mix_a0
+            )
+        );
+
+        ggml_tensor * kk = ggml_reshape_3d(ctx0, ggml_mul(ctx0, k, layer.time_mix_k_k), head_size, head_count, n_tokens);
+        kk = ggml_l2_norm(ctx0, kk, 1e-12);
 
+        ggml_tensor * ka = ggml_mul(ctx0, k, layer.time_mix_k_a);
+        k = ggml_add(ctx0, k, ggml_sub(ctx0, ggml_mul(ctx0, a, ka), ka));
+
+        r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens);
+        w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens);
+        k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens);
+        v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
+        a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
+
+        ggml_tensor * wkv_state = build_copy_mask_state(
+                gf, kv_self->v_l[il], state_copy, state_mask,
+                hparams.n_embd_v_s(), n_seqs);
+
+        ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
+        cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
+        wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
+
+        ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    wkv_state,
+                    ggml_view_1d(
+                        ctx0,
+                        kv_self->v_l[il],
+                        hparams.n_embd_v_s() * n_seqs,
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
+                        )
+                    )
+                );
+
+        if (layer.time_mix_ln && layer.time_mix_ln_b) {
+            // group norm with head_count groups
+            cur = ggml_reshape_3d(ctx0, cur, n_embd / head_count, head_count, n_tokens);
+            cur = ggml_norm(ctx0, cur, 64e-5f);
+
+            // Convert back to regular vectors.
+            cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b);
+        } else {
+            cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        }
+
+        ggml_tensor * rk = ggml_sum_rows(ctx0,
+                ggml_mul(ctx0, ggml_mul(ctx0, k, r), ggml_reshape_2d(ctx0, layer.time_mix_r_k, head_size, head_count)));
+        cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, ggml_mul(ctx0, v, rk), n_embd, n_tokens));
+
+        if (has_gating) {
+            cur = ggml_mul(ctx0, cur, g);
+        }
+        cur = build_lora_mm(layer.time_mix_output, cur);
+
+        return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
+    }
+};
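
Collected in one place, the per-head math that build_rwkv7_time_mix wires up, with symbols named after the tensors and x_{w,t}, x_{k,t}, ... denoting the token-shift interpolations produced by time_mix_lerp_fused (a sketch assuming ggml_rwkv_wkv7 matches the CPU reference added in ggml-cpu.c; the matrix form is a reconstruction, not code from the patch):

    \begin{aligned}
    w_t &= \exp\!\big({-e^{-1/2}}\,\sigma(w_0 + W_2\tanh(W_1 x_{w,t}))\big)
        && e^{-1/2} \approx 0.606531,\ \text{so each } w_{t,j} \in (0.545,\,1) \\
    v_t &\leftarrow v_t + \big(v^{(0)}_t - v_t\big)\odot\sigma(v_0 + V_2 V_1 x_{v,t})
        && \text{residual to the first layer's value, for later layers} \\
    a_t &= \sigma(a_0 + A_2 A_1 x_{a,t}), \quad
    \hat k_t = \frac{k_t \odot k_k}{\lVert k_t \odot k_k \rVert_2}, \quad
    k_t \leftarrow k_t \odot \big(1 + (a_t - 1)\odot k_a\big) \\
    S_t &= S_{t-1}\,\mathrm{diag}(w_t)
        - \big(S_{t-1}\hat k_t\big)\big(\hat k_t \odot a_t\big)^{\top}
        + v_t k_t^{\top} \\
    y_t &= S_t r_t + \langle r_t,\ k_t \odot r_k\rangle\, v_t
    \end{aligned}

S is the head_size x head_size wkv state kept in kv_self->v_l[il]; the fifth and sixth arguments of ggml_rwkv_wkv7 are the removal direction -k̂_t and the gated write direction k̂_t ⊙ a_t, with a_t acting as a per-channel in-context learning rate. The optional time_mix_ln group norm, applied between the op output and the bonus term, is elided above; when time_mix_g1/g2 are present, y_t is further scaled by g_t = G_2 σ(G_1 x_{g,t}) before the time_mix_output projection.
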
+
+struct llm_build_rwkv7 : public llm_build_rwkv7_base {
+    llm_build_rwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
+        GGML_ASSERT(hparams.token_shift_count == 2);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        ggml_tensor * v_first = nullptr;
+
+        inpL = build_inp_embd(model.tok_embd);
+        inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
+
+        ggml_tensor * state_copy = build_inp_s_copy();
+        ggml_tensor * state_mask = build_inp_s_mask();
+
+        const auto n_embd = hparams.n_embd;
+        const auto n_seq_tokens = ubatch.n_seq_tokens;
+        const auto n_seqs = ubatch.n_seqs;
+
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];
+            inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(
+                    gf, state_copy, state_mask, ubatch, il
+                    );
+
+            ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
+            ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
+
+            ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM, il);
+            cb(att_norm, "attn_norm", il);
+
+            ggml_tensor * x_prev = ggml_concat(
+                    ctx0,
+                    att_shift,
+                    ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0),
+                    1
+                    );
+
+            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            ggml_tensor * ffn_norm = build_norm(ffn_inp, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, il);
+            cb(ffn_norm, "ffn_norm", il);
+
+            x_prev = ggml_concat(
+                    ctx0,
+                    ffn_shift,
+                    ggml_view_3d(ctx0, ffn_norm, n_embd, n_seq_tokens - 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], 0),
+                    1
+                    );
+
+            token_shift = ggml_concat(ctx0,
+                    ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)),
+                    ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)),
+                    1
+                    );
+            ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                ffn_inp  = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp,  n_embd, n_tokens), inp_out_ids);
+                ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids);
+                x_prev   = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev,   n_embd, n_tokens), inp_out_ids);
+            }
+
+            cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
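
Token shift in llm_build_rwkv7 is a one-step delay per sequence: position 0 of each sequence mixes with the shift state carried over from the previous ubatch, every later position with its predecessor,

    x^{\mathrm{prev}}_t =
    \begin{cases}
    \text{token\_shift loaded from the recurrent state}, & t = 0 \\
    \text{att\_norm}_{t-1}\ (\text{resp. } \text{ffn\_norm}_{t-1}), & t > 0
    \end{cases}

and the concat at the end of each layer stores the per-sequence last column of att_norm and ffn_norm back for the next ubatch. RWKV7 keeps two such shift states per layer (hence the token_shift_count == 2 assert), while llm_build_arwkv7 below keeps a single one, since its feed-forward path is a Qwen2-style gated MLP with no channel mix.
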
+
+
+struct llm_build_arwkv7 : public llm_build_rwkv7_base {
+    llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
+        GGML_ASSERT(n_embd == hparams.n_embd_k_s());
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        ggml_tensor * v_first = nullptr;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        ggml_tensor * state_copy = build_inp_s_copy();
+        ggml_tensor * state_mask = build_inp_s_mask();
+
+        const auto n_embd = hparams.n_embd;
+        const auto n_seq_tokens = ubatch.n_seq_tokens;
+        const auto n_seqs = ubatch.n_seqs;
+
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];
+            inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(
+                    gf, state_copy, state_mask, ubatch, il
+                    );
+
+            ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
+            cb(att_norm, "attn_norm", il);
+
+            ggml_tensor * x_prev = ggml_concat(
+                    ctx0,
+                    token_shift,
+                    ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0),
+                    1
+                    );
+
+            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
+
+            token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
+            ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur     = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
+                ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
+            }
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
         cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1);
 
         cb(cur, "result_norm", -1);
@@ -10883,9 +11418,11 @@ llama_memory_i * llama_model::create_memory() const {
     llama_memory_i * res;
 
     switch (arch) {
+        case LLM_ARCH_MAMBA:
         case LLM_ARCH_RWKV6:
         case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_RWKV7:
+        case LLM_ARCH_ARWKV7:
             {
                 res = new llama_kv_cache_unified(hparams, {
                     /*.get_rope_factors =*/ nullptr
@@ -11132,6 +11669,14 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_rwkv6qwen2>(*this, params, gf);
             } break;
+        case LLM_ARCH_RWKV7:
+            {
+                llm = std::make_unique<llm_build_rwkv7>(*this, params, gf);
+            } break;
+        case LLM_ARCH_ARWKV7:
+            {
+                llm = std::make_unique<llm_build_arwkv7>(*this, params, gf);
+            } break;
         case LLM_ARCH_CHAMELEON:
             {
                 llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
@@ -11245,6 +11790,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_JAIS:
         case LLM_ARCH_RWKV6:
         case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_RWKV7:
+        case LLM_ARCH_ARWKV7:
         case LLM_ARCH_WAVTOKENIZER_DEC:
             return LLAMA_ROPE_TYPE_NONE;
 
@@ -11399,6 +11946,8 @@ bool llama_model_is_recurrent(const llama_model * model) {
         case     LLM_ARCH_MAMBA:      return true;
         case     LLM_ARCH_RWKV6:      return true;
         case     LLM_ARCH_RWKV6QWEN2: return true;
+        case     LLM_ARCH_RWKV7:      return true;
+        case     LLM_ARCH_ARWKV7:     return true;
         default: return false;
     }
 }
index 55c26a92b02d25656142badd7c0d5e8781bce606..a9da1215abbfd3013342f099067d50f16af57086 100644 (file)
@@ -29,6 +29,7 @@ enum llm_type {
     LLM_TYPE_109M,
     LLM_TYPE_137M,
     LLM_TYPE_160M,
+    LLM_TYPE_190M,
     LLM_TYPE_220M,
     LLM_TYPE_250M,
     LLM_TYPE_270M,
@@ -45,6 +46,7 @@ enum llm_type {
     LLM_TYPE_1_6B,
     LLM_TYPE_2B,
     LLM_TYPE_2_8B,
+    LLM_TYPE_2_9B,
     LLM_TYPE_3B,
     LLM_TYPE_4B,
     LLM_TYPE_6B,
@@ -260,6 +262,20 @@ struct llama_layer {
     struct ggml_tensor * time_mix_receptance_b = nullptr;
     struct ggml_tensor * time_mix_gate         = nullptr;
 
+    // rwkv7
+    struct ggml_tensor * time_mix_w0         = nullptr;
+    struct ggml_tensor * time_mix_a0         = nullptr;
+    struct ggml_tensor * time_mix_a1         = nullptr;
+    struct ggml_tensor * time_mix_a2         = nullptr;
+    struct ggml_tensor * time_mix_v0         = nullptr;
+    struct ggml_tensor * time_mix_v1         = nullptr;
+    struct ggml_tensor * time_mix_v2         = nullptr;
+    struct ggml_tensor * time_mix_g1         = nullptr;
+    struct ggml_tensor * time_mix_g2         = nullptr;
+    struct ggml_tensor * time_mix_k_k        = nullptr;
+    struct ggml_tensor * time_mix_k_a        = nullptr;
+    struct ggml_tensor * time_mix_r_k        = nullptr;
+
     struct ggml_tensor * time_mix_ln     = nullptr;
     struct ggml_tensor * time_mix_ln_b   = nullptr;
     struct ggml_tensor * time_mix_output = nullptr;
index fb7982655a373f3cc9f4e65b72a5e819534a8478..09eb570779ce538ffaaf4c4385d388ce63fa87bf 100644 (file)
@@ -756,10 +756,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // NOTE: can't use LLM_TN here because the layer number is not known
         quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
 
-        // do not quantize RWKV's time_mix_first tensors
+        // do not quantize RWKV's small yet 2D weights
         quantize &= name.find("time_mix_first.weight") == std::string::npos;
+        quantize &= name.find("time_mix_w0.weight") == std::string::npos;
         quantize &= name.find("time_mix_w1.weight") == std::string::npos;
         quantize &= name.find("time_mix_w2.weight") == std::string::npos;
+        quantize &= name.find("time_mix_v0.weight") == std::string::npos;
+        quantize &= name.find("time_mix_v1.weight") == std::string::npos;
+        quantize &= name.find("time_mix_v2.weight") == std::string::npos;
+        quantize &= name.find("time_mix_a0.weight") == std::string::npos;
+        quantize &= name.find("time_mix_a1.weight") == std::string::npos;
+        quantize &= name.find("time_mix_a2.weight") == std::string::npos;
+        quantize &= name.find("time_mix_g1.weight") == std::string::npos;
+        quantize &= name.find("time_mix_g2.weight") == std::string::npos;
         quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
         quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
         quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
index c86ffb64e9e890063f899202b18faaa1fbbdc0f7..adb749bd5ec9a6337e4343030b036274fbafff49 100644 (file)
@@ -1916,6 +1916,40 @@ struct test_gla : public test_case {
     }
 };
 
+// GGML_OP_RWKV_WKV7
+struct test_rwkv_wkv7 : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_size;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+    }
+
+    test_rwkv_wkv7(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_tokens = n_seq_tokens * n_seqs;
+        ggml_tensor * r   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * w   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * a   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * b   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        // Outputs may become NaN at long sequence lengths without these normalizations
+        a = ggml_l2_norm(ctx, a, 1e-7F);
+        b = ggml_l2_norm(ctx, b, 1e-7F);
+        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
+        ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
+        return out;
+    }
+};
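
One way to see why the test normalizes a and b: writing the step as S_t = S_{t-1} diag(w_t) + (S_{t-1} a_t) b_t^T + v_t k_t^T, the Frobenius norm obeys

    \lVert S_t\rVert_F \le \lVert S_{t-1}\rVert_F
        \left(\max_j\lvert w_{t,j}\rvert + \lVert a_t\rVert_2\,\lVert b_t\rVert_2\right)
        + \lVert v_t\rVert_2\,\lVert k_t\rVert_2 .

With random unnormalized inputs at head_size 64, ‖a‖‖b‖ is on the order of the head size, so the state can grow geometrically over the 128-token cases; after ggml_l2_norm the product is at most 1 and the recurrence stays tame. In the model graph this holds by construction, since the op is called with the unit-norm k̂ directions.
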
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -2972,6 +3006,32 @@ struct test_group_norm : public test_case {
     }
 };
 
+// GGML_OP_L2_NORM
+struct test_l2_norm : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_l2_norm(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {64, 64, 320, 1},
+            float eps = 1e-12f)
+        : type(type), ne(ne), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_l2_norm(ctx, a, eps);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
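
test_l2_norm covers the new GGML_OP_L2_NORM both across the shared eps sweep and at the 1e-12 value used by the model graph. Assuming the op mirrors PyTorch's F.normalize (which the RWKV-7 reference uses for k̂), each row along the first dimension is rescaled to unit L2 norm,

    y_i = \frac{x_i}{\max(\lVert x\rVert_2,\ \varepsilon)} ,

with ε only guarding the all-zero row.
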
+
 // GGML_OP_ACC
 struct test_acc : public test_case {
     const ggml_type type;
@@ -4006,8 +4066,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
         }
         test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+        test_cases.emplace_back(new test_l2_norm      (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
 
+    test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
+
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
@@ -4019,6 +4082,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
 
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 1, 1));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 1));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 4));
+    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 128, 4));
+
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));