]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama: add support for QRWKV6 model architecture (#11001)
authorMolly Sophia <redacted>
Fri, 10 Jan 2025 01:58:08 +0000 (09:58 +0800)
committerGitHub <redacted>
Fri, 10 Jan 2025 01:58:08 +0000 (09:58 +0800)
llama: add support for QRWKV6 model architecture (#11001)

* WIP: Add support for RWKV6Qwen2

Signed-off-by: Molly Sophia <redacted>
* RWKV: Some graph simplification

Signed-off-by: Molly Sophia <redacted>
* Add support for RWKV6Qwen2 with cpu and cuda GLA

Signed-off-by: Molly Sophia <redacted>
* RWKV6[QWEN2]: Concat lerp weights together to reduce cpu overhead

Signed-off-by: Molly Sophia <redacted>
* Fix some typos

Signed-off-by: Molly Sophia <redacted>
* code format changes

Signed-off-by: Molly Sophia <redacted>
* Fix wkv test & add gla test

Signed-off-by: Molly Sophia <redacted>
* Fix cuda warning

Signed-off-by: Molly Sophia <redacted>
* Update README.md

Signed-off-by: Molly Sophia <redacted>
* Update ggml/src/ggml-cuda/gla.cu

Co-authored-by: Georgi Gerganov <redacted>
* Fix fused lerp weights loading with RWKV6

Signed-off-by: Molly Sophia <redacted>
* better sanity check skipping for QRWKV6 in llama-quant

thanks @compilade

Signed-off-by: Molly Sophia <redacted>
Co-authored-by: compilade <redacted>
---------

Signed-off-by: Molly Sophia <redacted>
Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: compilade <redacted>
23 files changed:
README.md
convert_hf_to_gguf.py
ggml/include/ggml.h
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/gla.cu [new file with mode: 0644]
ggml/src/ggml-cuda/gla.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/wkv6.cu
ggml/src/ggml-sycl/wkv6.cpp
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml.c
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-hparams.cpp
src/llama-hparams.h
src/llama-model.cpp
src/llama-model.h
src/llama-quant.cpp
src/llama.cpp
tests/test-backend-ops.cpp

index a7101525648b0bf750119c6968bd1ef80865ca93..6302ac977aa94011bee99c63f18223809e675bc0 100644 (file)
--- a/README.md
+++ b/README.md
@@ -99,6 +99,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat)
 - [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a)
 - [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM)
+- [x] [QRWKV-6](https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1)
 - [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
 
 #### Multimodal
index 5562499aa4925849d6c515c93ba63a0fd5524496..cf317eeae608aa280f41a1f5f5b7cd19a253dac3 100755 (executable)
@@ -326,6 +326,7 @@ class Model:
                             gguf.MODEL_TENSOR.TIME_MIX_W2,
                             gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1,
                             gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2,
+                            gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED,
                             gguf.MODEL_TENSOR.POSNET_NORM1,
                             gguf.MODEL_TENSOR.POSNET_NORM2,
                         )
@@ -3316,6 +3317,8 @@ class Rwkv6Model(Model):
         # required by llama.cpp, unused
         self.gguf_writer.add_head_count(0)
 
+    lerp_weights: dict[int, dict[str, Tensor]] = {}
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         new_name = self.map_tensor_name(name)
 
@@ -3331,14 +3334,84 @@ class Rwkv6Model(Model):
         if new_name.endswith("time_mix_decay.weight") or "lerp" in new_name:
             data_torch = data_torch.squeeze()
 
-        rescale_every_n_layers = self.hparams["rescale_every"]
-        if rescale_every_n_layers > 0:
-            if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
-                data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
+        try:
+            rescale_every_n_layers = self.hparams["rescale_every"]
+            if rescale_every_n_layers > 0:
+                if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
+                    data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
+        except KeyError:
+            pass
+
+        # concat time_mix_lerp weights to reduce some cpu overhead
+        # also reduces the number of tensors in the model
+        if bid is not None and "time_mix_lerp" in new_name and "time_mix_lerp_x" not in new_name:
+            try:
+                self.lerp_weights[bid][new_name] = data_torch
+            except KeyError:
+                self.lerp_weights[bid] = {new_name: data_torch}
+            if all(f"blk.{bid}.time_mix_lerp_{i}.weight" in self.lerp_weights[bid].keys() for i in ["w", "k", "v", "r", "g"]):
+                new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                data = torch.stack([self.lerp_weights[bid][f"blk.{bid}.time_mix_lerp_{i}.weight"].unsqueeze(0) for i in ["w", "k", "v", "r", "g"]], dim=0).unsqueeze(1)
+                yield (new_name, data)
+            return
 
         yield (new_name, data_torch)
 
 
+@Model.register("RWKV6Qwen2ForCausalLM")
+class RWKV6Qwen2Model(Rwkv6Model):
+    model_arch = gguf.MODEL_ARCH.RWKV6QWEN2
+
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        num_attention_heads = self.hparams["num_attention_heads"]
+        num_key_value_heads = self.hparams["num_key_value_heads"]
+        hidden_size = self.hparams["hidden_size"]
+        head_size = hidden_size // num_attention_heads
+        rms_norm_eps = self.hparams["rms_norm_eps"]
+        intermediate_size = self.hparams["intermediate_size"]
+        time_mix_extra_dim = 64 if hidden_size >= 4096 else 32
+        time_decay_extra_dim = 128 if hidden_size >= 4096 else 64
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
+        self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+
+        # special parameters for time_mixing in RWKV6QWEN2
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_token_shift_count(1)
+        # RWKV6QWEN2 use grouped key/value like GQA
+        self.gguf_writer.add_head_count_kv(num_key_value_heads)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        for new_name, data in super().modify_tensors(data_torch, name, bid):
+            if "time_mix_w1" in new_name or "time_mix_w2" in new_name:
+                data = data.view(5, -1, data.shape[-1])
+                # rwkv6qwen2 has a different order of rkvwg instead of the original wkvrg
+                # permute them here to avoid code changes
+                data = torch.stack([data[3], data[1], data[2], data[0], data[4]], dim=0).view(-1, data.shape[-1])
+                if "w2" in new_name:
+                    data = data.view(5, -1, data.shape[-1])
+                yield (new_name, data)
+                continue
+            yield (new_name, data)
+
+
 @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA
index 8630d92c5c6a496150fbf87225dc5a5b0f16a641..8f8cb9e1aa1401c536db42620d6809be5b55b7c8 100644 (file)
@@ -501,6 +501,7 @@ extern "C" {
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
+        GGML_OP_GATED_LINEAR_ATTN,
 
         GGML_OP_UNARY,
 
@@ -1859,6 +1860,15 @@ extern "C" {
             struct ggml_tensor  * td,
             struct ggml_tensor  * state);
 
+    GGML_API struct ggml_tensor * ggml_gated_linear_attn(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * state,
+            float scale);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
index b7fefb9ddfd894be76566e116f2a9cd8ede1047f..2966ff7682de2c001d372842e787158369907eb8 100644 (file)
@@ -11803,9 +11803,9 @@ static void ggml_compute_forward_add_rel_pos(
 static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[3];
+    const int64_t T = dst->src[1]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t HEADS = dst->src[1]->ne[1];
     const int64_t n_seqs = dst->src[5]->ne[1];
     const int64_t head_size = C / HEADS;
 
@@ -12000,6 +12000,197 @@ static void ggml_compute_forward_rwkv_wkv6(
     }
 }
 
+// ggml_compute_forward_gla
+
+static void ggml_compute_forward_gla_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[4]->ne[1];
+    const int64_t head_size = C / HEADS;
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * q = (float *) dst->src[2]->data;
+    float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // Same to C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define GLA_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define GLA_VECTOR_SIZE 4
+    #endif
+
+    #ifdef GLA_VECTOR_SIZE
+        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
+                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load x elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = prev_state * g + kv
+                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+                        // Update dst: dst += temp * q
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+                    }
+
+                    // Handle remaining elements, this will not be used.
+                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val + prev_state_val * g_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = prev_state_val * g_val + kv_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
+
+
+static void ggml_compute_forward_gla(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gla_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -12749,6 +12940,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            {
+                ggml_compute_forward_gla(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -13047,6 +13242,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
index 0b06be729864e6de01b8d97a5da858f5017a4b3f..8476ee1bca50c8c0cd8645f0238e7a4ef8f691d5 100644 (file)
@@ -37,6 +37,7 @@
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/gla.cuh"
 
 #include <algorithm>
 #include <array>
@@ -2167,6 +2168,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_cuda_op_gated_linear_attn(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
@@ -3011,6 +3015,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
diff --git a/ggml/src/ggml-cuda/gla.cu b/ggml/src/ggml-cuda/gla.cu
new file mode 100644 (file)
index 0000000..f7d615a
--- /dev/null
@@ -0,0 +1,93 @@
+#include "common.cuh"
+#include "gla.cuh"
+
+template<int HEAD_SIZE>
+static __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale,
+     const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = HEAD_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _td[head_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4 & k = (float4 &)(_k[j]);
+            const float4 & r = (float4 &)(_r[j]);
+            const float4 & td = (float4 &)(_td[j]);
+            float4 & s = (float4 &)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+
+            y += r.x * s.x;
+            y += r.y * s.y;
+            y += r.z * s.z;
+            y += r.w * s.w;
+        }
+        dst[t] = y * scale;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d  = (const float *)dst->src[0]->data;
+    const float * v_d  = (const float *)dst->src[1]->data;
+    const float * r_d  = (const float *)dst->src[2]->data;
+    const float * td_d = (const float *)dst->src[3]->data;
+    const float * s_d  = (const float *)dst->src[4]->data;
+
+    const int64_t B = dst->src[4]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float scale;
+    memcpy(&scale, (float*)dst->op_params, sizeof(float));
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == 64 || C / H == 128);
+
+
+    if (C / H == 64) {
+        gated_linear_attn_f32<64><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    } else {
+        gated_linear_attn_f32<128><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    }
+}
diff --git a/ggml/src/ggml-cuda/gla.cuh b/ggml/src/ggml-cuda/gla.cuh
new file mode 100644 (file)
index 0000000..2c82ad7
--- /dev/null
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 42578341a386b5ff5178ba465b5ee99f48d175f9..bbdafbee5818ba666f282c7180f33acdabb2df87 100644 (file)
@@ -73,9 +73,9 @@ void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const float * s_d  = (const float *)dst->src[5]->data;
 
     const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[3];
+    const int64_t T = dst->src[0]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[2];
+    const int64_t H = dst->src[0]->ne[1];
 
     float * dst_d = (float *)dst->data;
 
index 4fed18c2aafb14ca6c3ba5d17245ea003af30a08..b54c20964ed5dfc0bb820a8af599bb2c2d06953a 100644 (file)
@@ -109,9 +109,9 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
     float* dst_d = (float*)dst->data;
 
     const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[3];
+    const int64_t T = dst->src[0]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[2];
+    const int64_t H = dst->src[0]->ne[1];
 
     GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
     GGML_ASSERT(C % H == 0);
index 0774524242a6b3cfa4b00c271092b0ca921226e1..1b917468239ddd49a385e3766444b87d0ffa5451 100644 (file)
@@ -5633,9 +5633,9 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
 }
 
 static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
-    const size_t seq_length = dst->src[0]->ne[3];
+    const size_t seq_length = dst->src[0]->ne[2];
     const size_t n_embed = dst->ne[0];
-    const size_t n_heads = dst->src[0]->ne[2];
+    const size_t n_heads = dst->src[0]->ne[1];
     const size_t n_seqs = dst->src[5]->ne[1];
 
     ggml_vk_op_f32_rwkv6(
index 90abc6ad45233aab20ada03b8e7a9c6736e506c1..da5b817e1563785bbc9ffa48e559019c5d8b5c63 100644 (file)
@@ -968,6 +968,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GET_REL_POS",
     "ADD_REL_POS",
     "RWKV_WKV6",
+    "GATED_LINEAR_ATTN",
 
     "UNARY",
 
@@ -987,7 +988,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1064,6 +1065,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "get_rel_pos(x)",
     "add_rel_pos(x)",
     "rwkv_wkv6(k, v, r, tf, td, s)",
+    "gated_linear_attn(k, v, q, gate, s)",
 
     "unary(x)",
 
@@ -1083,7 +1085,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4629,15 +4631,13 @@ struct ggml_tensor * ggml_rwkv_wkv6(
     GGML_ASSERT(ggml_is_contiguous(state));
 
     const int64_t S = k->ne[0];
-    const int64_t H = k->ne[2];
-    const int64_t n_tokens = k->ne[3];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
     const int64_t n_seqs = state->ne[1];
     {
-        GGML_ASSERT(k->ne[1] == 1);
-        GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
-        GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
-        // TODO: RWKV v4 and v5
-        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
+        GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
         GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
     }
 
@@ -4656,6 +4656,49 @@ struct ggml_tensor * ggml_rwkv_wkv6(
     return result;
 }
 
+// ggml_gated_linear_attn
+
+struct ggml_tensor * ggml_gated_linear_attn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * state,
+        float scale) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
+        GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_set_op_params_f32(result, 0, scale);
+
+    result->op     = GGML_OP_GATED_LINEAR_ATTN;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = q;
+    result->src[3] = g;
+    result->src[4] = state;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
index cf05bf47ece0830cbf699d8580df2015b5208522..56aa9288dad40cbdb25fe8022381cff8022eb08a 100644 (file)
@@ -115,6 +115,7 @@ class Keys:
         TIME_DECAY_EXTRA_DIM              = "{arch}.time_decay_extra_dim"
         RESIDUAL_SCALE                    = "{arch}.residual_scale"
         EMBEDDING_SCALE                   = "{arch}.embedding_scale"
+        TOKEN_SHIFT_COUNT                 = "{arch}.token_shift_count"
 
     class Attention:
         HEAD_COUNT        = "{arch}.attention.head_count"
@@ -255,6 +256,7 @@ class MODEL_ARCH(IntEnum):
     GEMMA2           = auto()
     STARCODER2       = auto()
     RWKV6            = auto()
+    RWKV6QWEN2       = auto()
     MAMBA            = auto()
     XVERSE           = auto()
     COMMAND_R        = auto()
@@ -334,6 +336,7 @@ class MODEL_TENSOR(IntEnum):
     TIME_MIX_LERP_V      = auto()
     TIME_MIX_LERP_R      = auto()
     TIME_MIX_LERP_G      = auto()
+    TIME_MIX_LERP_FUSED  = auto()
     TIME_MIX_LERP_W      = auto()
     TIME_MIX_FIRST       = auto()
     TIME_MIX_DECAY       = auto()
@@ -440,6 +443,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.GEMMA2:           "gemma2",
     MODEL_ARCH.STARCODER2:       "starcoder2",
     MODEL_ARCH.RWKV6:            "rwkv6",
+    MODEL_ARCH.RWKV6QWEN2:       "rwkv6qwen2",
     MODEL_ARCH.MAMBA:            "mamba",
     MODEL_ARCH.XVERSE:           "xverse",
     MODEL_ARCH.COMMAND_R:        "command-r",
@@ -519,6 +523,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.TIME_MIX_LERP_V:           "blk.{bid}.time_mix_lerp_v",
     MODEL_TENSOR.TIME_MIX_LERP_R:           "blk.{bid}.time_mix_lerp_r",
     MODEL_TENSOR.TIME_MIX_LERP_G:           "blk.{bid}.time_mix_lerp_g",
+    MODEL_TENSOR.TIME_MIX_LERP_FUSED:       "blk.{bid}.time_mix_lerp_fused",
     MODEL_TENSOR.TIME_MIX_LERP_W:           "blk.{bid}.time_mix_lerp_w",
     MODEL_TENSOR.TIME_MIX_FIRST:            "blk.{bid}.time_mix_first",
     MODEL_TENSOR.TIME_MIX_DECAY:            "blk.{bid}.time_mix_decay",
@@ -1103,6 +1108,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.TIME_MIX_LERP_R,
         MODEL_TENSOR.TIME_MIX_LERP_G,
         MODEL_TENSOR.TIME_MIX_LERP_W,
+        MODEL_TENSOR.TIME_MIX_LERP_FUSED,
         MODEL_TENSOR.TIME_MIX_FIRST,
         MODEL_TENSOR.TIME_MIX_DECAY,
         MODEL_TENSOR.TIME_MIX_DECAY_W1,
@@ -1119,6 +1125,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE,
         MODEL_TENSOR.CHANNEL_MIX_VALUE,
     ],
+    MODEL_ARCH.RWKV6QWEN2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.TIME_MIX_W1,
+        MODEL_TENSOR.TIME_MIX_W2,
+        MODEL_TENSOR.TIME_MIX_LERP_X,
+        MODEL_TENSOR.TIME_MIX_LERP_K,
+        MODEL_TENSOR.TIME_MIX_LERP_V,
+        MODEL_TENSOR.TIME_MIX_LERP_R,
+        MODEL_TENSOR.TIME_MIX_LERP_G,
+        MODEL_TENSOR.TIME_MIX_LERP_W,
+        MODEL_TENSOR.TIME_MIX_LERP_FUSED,
+        MODEL_TENSOR.TIME_MIX_FIRST,
+        MODEL_TENSOR.TIME_MIX_DECAY,
+        MODEL_TENSOR.TIME_MIX_DECAY_W1,
+        MODEL_TENSOR.TIME_MIX_DECAY_W2,
+        MODEL_TENSOR.TIME_MIX_KEY,
+        MODEL_TENSOR.TIME_MIX_VALUE,
+        MODEL_TENSOR.TIME_MIX_RECEPTANCE,
+        MODEL_TENSOR.TIME_MIX_GATE,
+        MODEL_TENSOR.TIME_MIX_LN,
+        MODEL_TENSOR.TIME_MIX_OUTPUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.MAMBA: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index 4a0a65e3cc33ec2efa5db1862fd88aa77f41877f..bf851c92ca5483d961810e09cb22fdee211c1d6d 100644 (file)
@@ -743,6 +743,9 @@ class GGUFWriter:
     def add_wkv_head_size(self, size: int) -> None:
         self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
 
+    def add_token_shift_count(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count)
+
     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
 
index 7616c468a53015f9a6f6d460ac447cf6e61915a9..617791e240b6053055209db9790ad365ad34f6d0 100644 (file)
@@ -13,7 +13,7 @@ class TensorNameMap:
             "transformer.wte",                           # gpt2 gpt-j mpt refact qwen dbrx jais exaone
             "transformer.word_embeddings",               # falcon
             "word_embeddings",                           # bloom
-            "model.embed_tokens",                        # llama-hf nemotron olmoe olmo2
+            "model.embed_tokens",                        # llama-hf nemotron olmoe olmo2 rwkv6qwen2
             "tok_embeddings",                            # llama-pth
             "embeddings.word_embeddings",                # bert nomic-bert
             "language_model.embedding.word_embeddings",  # persimmon
@@ -464,34 +464,42 @@ class TensorNameMap:
 
         MODEL_TENSOR.TIME_MIX_W1: (
             "rwkv.blocks.{bid}.attention.time_maa_w1",  # rwkv v6
+            "model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_W2: (
             "rwkv.blocks.{bid}.attention.time_maa_w2",  # rwkv v6
+            "model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_LERP_X: (
             "rwkv.blocks.{bid}.attention.time_maa_x",   # rwkv v6
+            "model.layers.{bid}.self_attn.time_maa_x",  # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_LERP_K: (
             "rwkv.blocks.{bid}.attention.time_maa_k",   # rwkv v6
+            "model.layers.{bid}.self_attn.time_maa_k",  # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_LERP_V: (
             "rwkv.blocks.{bid}.attention.time_maa_v",   # rwkv v6
+            "model.layers.{bid}.self_attn.time_maa_v",  # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_LERP_R: (
             "rwkv.blocks.{bid}.attention.time_maa_r",   # rwkv v6
+            "model.layers.{bid}.self_attn.time_maa_r",  # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_LERP_G: (
             "rwkv.blocks.{bid}.attention.time_maa_g",   # rwkv v6
+            "model.layers.{bid}.self_attn.time_maa_g",  # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_LERP_W: (
             "rwkv.blocks.{bid}.attention.time_maa_w",   # rwkv v6
+            "model.layers.{bid}.self_attn.time_maa_w",  # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_FIRST: (
@@ -500,30 +508,37 @@ class TensorNameMap:
 
         MODEL_TENSOR.TIME_MIX_DECAY: (
             "rwkv.blocks.{bid}.attention.time_decay",   # rwkv v6
+            "model.layers.{bid}.self_attn.time_decay",  # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_DECAY_W1: (
             "rwkv.blocks.{bid}.attention.time_decay_w1",  # rwkv v6
+            "model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_DECAY_W2: (
             "rwkv.blocks.{bid}.attention.time_decay_w2",  # rwkv v6
+            "model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_KEY: (
-            "rwkv.blocks.{bid}.attention.key", # rwkv
+            "rwkv.blocks.{bid}.attention.key",     # rwkv
+            "model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_VALUE: (
-            "rwkv.blocks.{bid}.attention.value", # rwkv
+            "rwkv.blocks.{bid}.attention.value",   # rwkv
+            "model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
             "rwkv.blocks.{bid}.attention.receptance", # rwkv
+            "model.layers.{bid}.self_attn.q_proj",    # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_GATE: (
-            "rwkv.blocks.{bid}.attention.gate", # rwkv
+            "rwkv.blocks.{bid}.attention.gate",  # rwkv
+            "model.layers.{bid}.self_attn.gate", # rwkv6qwen2
         ),
 
         MODEL_TENSOR.TIME_MIX_LN: (
@@ -531,7 +546,8 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.TIME_MIX_OUTPUT: (
-            "rwkv.blocks.{bid}.attention.output", # rwkv
+            "rwkv.blocks.{bid}.attention.output",  # rwkv
+            "model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2
         ),
 
         MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
index eef66ed311d7a202e9ebbff78c59cd2ffda7d55a..7300bd26a907f87f8d1a85f4f0ea327de35ece6b 100644 (file)
@@ -57,6 +57,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_NEMOTRON,         "nemotron"         },
     { LLM_ARCH_EXAONE,           "exaone"           },
     { LLM_ARCH_RWKV6,            "rwkv6"            },
+    { LLM_ARCH_RWKV6QWEN2,       "rwkv6qwen2"       },
     { LLM_ARCH_GRANITE,          "granite"          },
     { LLM_ARCH_GRANITE_MOE,      "granitemoe"       },
     { LLM_ARCH_CHAMELEON,        "chameleon"        },
@@ -106,6 +107,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TIME_DECAY_EXTRA_DIM,              "%s.time_decay_extra_dim"              },
     { LLM_KV_RESIDUAL_SCALE,                    "%s.residual_scale"                    },
     { LLM_KV_EMBEDDING_SCALE,                   "%s.embedding_scale"                   },
+    { LLM_KV_TOKEN_SHIFT_COUNT,                 "%s.token_shift_count"                 },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,             "%s.attention.head_count"             },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,          "%s.attention.head_count_kv"          },
@@ -1166,6 +1168,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_TIME_MIX_LERP_V,           "blk.%d.time_mix_lerp_v" },
             { LLM_TENSOR_TIME_MIX_LERP_R,           "blk.%d.time_mix_lerp_r" },
             { LLM_TENSOR_TIME_MIX_LERP_G,           "blk.%d.time_mix_lerp_g" },
+            { LLM_TENSOR_TIME_MIX_LERP_FUSED,       "blk.%d.time_mix_lerp_fused" },
             { LLM_TENSOR_TIME_MIX_FIRST,            "blk.%d.time_mix_first" },
             { LLM_TENSOR_TIME_MIX_DECAY,            "blk.%d.time_mix_decay" },
             { LLM_TENSOR_TIME_MIX_DECAY_W1,         "blk.%d.time_mix_decay_w1" },
@@ -1183,6 +1186,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,    "blk.%d.channel_mix_receptance" },
         },
     },
+    {
+        LLM_ARCH_RWKV6QWEN2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,                "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,               "output_norm" },
+            { LLM_TENSOR_OUTPUT,                    "output" },
+            { LLM_TENSOR_ATTN_NORM,                 "blk.%d.attn_norm" },
+            { LLM_TENSOR_TIME_MIX_W1,               "blk.%d.time_mix_w1" },
+            { LLM_TENSOR_TIME_MIX_W2,               "blk.%d.time_mix_w2" },
+            { LLM_TENSOR_TIME_MIX_LERP_X,           "blk.%d.time_mix_lerp_x" },
+            { LLM_TENSOR_TIME_MIX_LERP_FUSED,       "blk.%d.time_mix_lerp_fused" },
+            { LLM_TENSOR_TIME_MIX_FIRST,            "blk.%d.time_mix_first" },
+            { LLM_TENSOR_TIME_MIX_DECAY,            "blk.%d.time_mix_decay" },
+            { LLM_TENSOR_TIME_MIX_DECAY_W1,         "blk.%d.time_mix_decay_w1" },
+            { LLM_TENSOR_TIME_MIX_DECAY_W2,         "blk.%d.time_mix_decay_w2" },
+            { LLM_TENSOR_TIME_MIX_KEY,              "blk.%d.time_mix_key" },
+            { LLM_TENSOR_TIME_MIX_VALUE,            "blk.%d.time_mix_value" },
+            { LLM_TENSOR_TIME_MIX_RECEPTANCE,       "blk.%d.time_mix_receptance" },
+            { LLM_TENSOR_TIME_MIX_GATE,             "blk.%d.time_mix_gate" },
+            { LLM_TENSOR_TIME_MIX_OUTPUT,           "blk.%d.time_mix_output" },
+            { LLM_TENSOR_FFN_NORM,                  "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,                  "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,                  "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,                    "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_GRANITE,
         {
@@ -1365,6 +1394,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_TIME_MIX_LERP_V,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_R,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_G,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_LERP_FUSED,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_DECAY,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_FIRST,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
     {LLM_TENSOR_ATTN_NORM,                  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
index 2e5f97b771d0ef0ace44f5f0b0d36c7ade35a632..79909f03fc522cf7ef5af9b32d9dc08313b54305 100644 (file)
@@ -61,6 +61,7 @@ enum llm_arch {
     LLM_ARCH_NEMOTRON,
     LLM_ARCH_EXAONE,
     LLM_ARCH_RWKV6,
+    LLM_ARCH_RWKV6QWEN2,
     LLM_ARCH_GRANITE,
     LLM_ARCH_GRANITE_MOE,
     LLM_ARCH_CHAMELEON,
@@ -110,6 +111,7 @@ enum llm_kv {
     LLM_KV_TIME_DECAY_EXTRA_DIM,
     LLM_KV_RESIDUAL_SCALE,
     LLM_KV_EMBEDDING_SCALE,
+    LLM_KV_TOKEN_SHIFT_COUNT,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -253,6 +255,7 @@ enum llm_tensor {
     LLM_TENSOR_TIME_MIX_LERP_V,
     LLM_TENSOR_TIME_MIX_LERP_R,
     LLM_TENSOR_TIME_MIX_LERP_G,
+    LLM_TENSOR_TIME_MIX_LERP_FUSED,
     LLM_TENSOR_TIME_MIX_FIRST,
     LLM_TENSOR_TIME_MIX_DECAY,
     LLM_TENSOR_TIME_MIX_DECAY_W1,
index c40534696b65f7b4c9238bc63eb961039ca81e55..ea87b2953d9ddb7af448bc84b4657d8882d1ba53 100644 (file)
@@ -52,7 +52,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
 uint32_t llama_hparams::n_embd_k_s() const {
     if (wkv_head_size != 0) {
         // for RWKV models
-        return 2 * n_embd;
+        return token_shift_count * n_embd;
     }
 
     // TODO: maybe support other convolution strides than 1
index a29f20ec49665b758c31f9593ef486acfa5e5f7b..3542bef499eace9d66e043880a8c144b7f7ff998 100644 (file)
@@ -76,6 +76,7 @@ struct llama_hparams {
     uint32_t time_mix_extra_dim     = 0;
     uint32_t time_decay_extra_dim   = 0;
     uint32_t wkv_head_size          = 0;
+    uint32_t token_shift_count      = 2;
 
     float    rope_attn_factor = 1.0f;
     float    rope_freq_base_train;
index 7260cb155261b5282ba615b44b0919f7669b37dd..c056204b0995efbebe7799c2149f2321ceb11f08 100644 (file)
@@ -1054,12 +1054,15 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
                 }
             } break;
         case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false);
                 ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
                 ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
                 ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
                 ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false);
+                ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false);
 
                 switch (hparams.n_layer) {
                     case 24: model.type = e_model::MODEL_1_6B; break;
@@ -1070,6 +1073,7 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
                             default: model.type = e_model::MODEL_UNKNOWN;
                         } break;
                     case 61: model.type = e_model::MODEL_14B; break;
+                    case 64: model.type = e_model::MODEL_32B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
@@ -2064,6 +2068,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_T5ENCODER:
         case LLM_ARCH_JAIS:
         case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
         case LLM_ARCH_WAVTOKENIZER_DEC:
             return LLAMA_ROPE_TYPE_NONE;
 
@@ -2208,6 +2213,7 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
     switch (model->arch) {
         case LLM_ARCH_MAMBA:  return true;
         case LLM_ARCH_RWKV6:  return true;
+        case LLM_ARCH_RWKV6QWEN2: return true;
         default:              return false;
     }
 }
index 424cb0f521943af787e0cc2274e96dbb094eee12..565d2dbdf6ff15df3fa3ac263b4ead4b83c4bc01 100644 (file)
@@ -241,15 +241,19 @@ struct llama_layer {
     struct ggml_tensor * time_mix_lerp_v     = nullptr;
     struct ggml_tensor * time_mix_lerp_r     = nullptr;
     struct ggml_tensor * time_mix_lerp_g     = nullptr;
-
-    struct ggml_tensor * time_mix_first      = nullptr;
-    struct ggml_tensor * time_mix_decay      = nullptr;
-    struct ggml_tensor * time_mix_decay_w1   = nullptr;
-    struct ggml_tensor * time_mix_decay_w2   = nullptr;
-    struct ggml_tensor * time_mix_key        = nullptr;
-    struct ggml_tensor * time_mix_value      = nullptr;
-    struct ggml_tensor * time_mix_receptance = nullptr;
-    struct ggml_tensor * time_mix_gate       = nullptr;
+    struct ggml_tensor * time_mix_lerp_fused = nullptr;
+
+    struct ggml_tensor * time_mix_first        = nullptr;
+    struct ggml_tensor * time_mix_decay        = nullptr;
+    struct ggml_tensor * time_mix_decay_w1     = nullptr;
+    struct ggml_tensor * time_mix_decay_w2     = nullptr;
+    struct ggml_tensor * time_mix_key          = nullptr;
+    struct ggml_tensor * time_mix_key_b        = nullptr;
+    struct ggml_tensor * time_mix_value        = nullptr;
+    struct ggml_tensor * time_mix_value_b      = nullptr;
+    struct ggml_tensor * time_mix_receptance   = nullptr;
+    struct ggml_tensor * time_mix_receptance_b = nullptr;
+    struct ggml_tensor * time_mix_gate         = nullptr;
 
     struct ggml_tensor * time_mix_ln     = nullptr;
     struct ggml_tensor * time_mix_ln_b   = nullptr;
index 466e7bc61b55969758f92ac6a3581167fde2f508..a45044f30625476c00efb924ea7e839d0b5a459a 100644 (file)
@@ -620,7 +620,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
 
     qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
 
-    // sanity checks
+    // sanity checks for models that have attention layers
+    if (qs.n_attention_wv != 0)
     {
         const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
         // attention layers have a non-zero number of kv heads
@@ -758,6 +759,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         quantize &= name.find("time_mix_w2.weight") == std::string::npos;
         quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
         quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
+        quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
 
         // do not quantize relative position bias (T5)
         quantize &= name.find("attn_rel_b.weight") == std::string::npos;
index ae375bcd3c8b1fb914da3cdc9187340853831367..a364861d3c8039f2536227ea81e430404221c44f 100644 (file)
@@ -134,11 +134,11 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
                 const int64_t H = 123;
                 const int64_t n_tokens = 123;
                 const int64_t n_seqs = 123;
-                ggml_tensor  * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens);
-                ggml_tensor  * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
-                ggml_tensor  * r = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
+                ggml_tensor  * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
+                ggml_tensor  * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
+                ggml_tensor  * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
                 ggml_tensor  * tf = w;
-                ggml_tensor  * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
+                ggml_tensor  * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
                 ggml_tensor  * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
                 op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
             } break;
@@ -2186,11 +2186,13 @@ static bool llm_load_tensors(
                         layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
 
                         layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL));
 
                         layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0);
                         layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
@@ -2214,6 +2216,59 @@ static bool llm_load_tensors(
                     }
 
                 } break;
+            case LLM_ARCH_RWKV6QWEN2:
+                {
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    const int time_mix_extra_dim = hparams.time_mix_extra_dim;
+                    const int time_decay_extra_dim = hparams.time_decay_extra_dim;
+                    const int head_size = hparams.wkv_head_size;
+                    const int attn_hidden_size = n_embd;
+                    const int n_head_kv = hparams.n_head_kv();
+                    int attn_key_value_size;
+                    if (n_head_kv == 0 || attn_hidden_size / head_size == n_head_kv) {
+                        attn_key_value_size = attn_hidden_size;
+                    } else {
+                        attn_key_value_size = n_head_kv * head_size;
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0);
+                        layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
+
+                        layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0);
+
+                        layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
+                        layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
+                        layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
+                        layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, attn_key_value_size}, 0);
+                        layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, attn_key_value_size}, 0);
+                        layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        // optional bias tensors
+                        layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
             case LLM_ARCH_CHAMELEON:
                 {
                     model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -3337,16 +3392,20 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         const struct llama_layer * layer,
         struct ggml_tensor * cur,
         struct ggml_tensor * x_prev,
-        struct ggml_tensor ** wkv_state) {
+        struct ggml_tensor ** wkv_state,
+        size_t wkv_head_size,
+        size_t head_count_kv) {
     size_t n_embd       = cur->ne[0];
     size_t n_seq_tokens = cur->ne[1];
     size_t n_seqs       = cur->ne[2];
 
-    size_t head_size  = layer->time_mix_first->ne[0];
-    size_t head_count = layer->time_mix_first->ne[1];
+    size_t head_size  = wkv_head_size;
+    size_t head_count = n_embd / head_size;
 
     size_t n_tokens = n_seqs * n_seq_tokens;
 
+    bool is_qrwkv = layer->time_mix_first == nullptr;
+
     struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
 
     sx  = ggml_reshape_2d(ctx, sx,  n_embd, n_tokens);
@@ -3375,69 +3434,64 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         xxx
     );
 
-    struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
-    struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
-    struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
-    struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
-    struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
-
-    struct ggml_tensor * xw = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mw, layer->time_mix_lerp_w),
-            sx
-        ),
-        cur
-    );
+    struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
+    if (layer->time_mix_lerp_fused) {
+        // fusing these weights makes some performance improvement
+        sx  = ggml_reshape_3d(ctx, sx,  n_embd, 1, n_tokens);
+        cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
+        xxx = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xxx, layer->time_mix_lerp_fused), sx), cur);
+        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+    } else {
+        // for backward compatibility
+        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
 
-    struct ggml_tensor * xk = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mk, layer->time_mix_lerp_k),
-            sx
-        ),
-        cur
-    );
+        xw = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xw, layer->time_mix_lerp_w), sx), cur);
+        xk = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xk, layer->time_mix_lerp_k), sx), cur);
+        xv = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xv, layer->time_mix_lerp_v), sx), cur);
+        xr = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xr, layer->time_mix_lerp_r), sx), cur);
+        xg = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xg, layer->time_mix_lerp_g), sx), cur);
+    }
 
-    struct ggml_tensor * xv = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mv, layer->time_mix_lerp_v),
-            sx
-        ),
-        cur
-    );
+    struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
+    struct ggml_tensor * k = llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk);
+    struct ggml_tensor * v = llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv);
+    if (layer->time_mix_receptance_b) {
+        r = ggml_add(ctx, r, layer->time_mix_receptance_b);
+    }
+    if (layer->time_mix_key_b) {
+        k = ggml_add(ctx, k, layer->time_mix_key_b);
+    }
+    if (layer->time_mix_value_b) {
+        v = ggml_add(ctx, v, layer->time_mix_value_b);
+    }
 
-    struct ggml_tensor * xr = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mr, layer->time_mix_lerp_r),
-            sx
-        ),
-        cur
-    );
+    struct ggml_tensor * g = llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg);
+    if (is_qrwkv) {
+        g = ggml_sigmoid(ctx, g);
+    } else {
+        g = ggml_silu(ctx, g);
+    }
 
-    struct ggml_tensor * xg = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mg, layer->time_mix_lerp_g),
-            sx
-        ),
-        cur
-    );
+    if (head_count_kv != head_count) {
+        GGML_ASSERT(head_count % head_count_kv == 0);
+        k = ggml_reshape_4d(ctx, k, head_size, 1, head_count_kv, n_tokens);
+        v = ggml_reshape_4d(ctx, v, head_size, 1, head_count_kv, n_tokens);
+        struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_size, head_count / head_count_kv, head_count_kv, n_tokens);
+        k = ggml_repeat(ctx, k, tmp);
+        v = ggml_repeat(ctx, v, tmp);
+    }
 
-    struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1,         head_count, n_tokens);
-    struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk), 1,         head_size, head_count, n_tokens);
-    struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv), head_size, 1,         head_count, n_tokens);
-    struct ggml_tensor * g = ggml_silu(
-        ctx,
-        llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
-    );
+    k = ggml_reshape_3d(ctx, k, head_size, head_count, n_tokens);
+    v = ggml_reshape_3d(ctx, v, head_size, head_count, n_tokens);
+    r = ggml_reshape_3d(ctx, r, head_size, head_count, n_tokens);
 
     struct ggml_tensor * w = ggml_mul_mat(
         ctx,
@@ -3448,25 +3502,35 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         )
     );
 
-    w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embd));
+    w = ggml_add(ctx, w, layer->time_mix_decay);
     w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
-    w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
+    w = ggml_reshape_3d(ctx, w, head_size, head_count, n_tokens);
 
-    k = ggml_transpose(ctx, k);
-    v = ggml_transpose(ctx, v);
-    r = ggml_transpose(ctx, r);
+    if (is_qrwkv) {
+        // k = k * (1 - w)
+        k = ggml_sub(ctx, k, ggml_mul(ctx, k, w));
+    }
 
-    struct ggml_tensor * wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+    struct ggml_tensor * wkv_output;
+    if (!layer->time_mix_first) {
+        wkv_output = ggml_gated_linear_attn(ctx, k, v, r, w, *wkv_state, pow(head_size, -0.5f));
+    } else {
+        wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+    }
     cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
     *wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
 
-    // group norm with head_count groups
-    cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
-    cur = ggml_norm(ctx, cur, 64e-5f);
+    if (!is_qrwkv) {
+        // group norm with head_count groups
+        cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
+        cur = ggml_norm(ctx, cur, 64e-5f);
 
-    // Convert back to regular vectors.
-    cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
-    cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+        // Convert back to regular vectors.
+        cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+        cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+    } else {
+        cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+    }
 
     cur = ggml_mul(ctx, cur, g);
     cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
@@ -10048,7 +10112,7 @@ struct llm_build_context {
                 1
             );
 
-            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, n_embd / hparams.wkv_head_size));
             ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
@@ -10115,6 +10179,118 @@ struct llm_build_context {
         return gf;
     }
 
+    // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
+    ggml_cgraph * build_rwkv6qwen2() {
+        ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        GGML_ASSERT(n_embd == hparams.n_embd_k_s());
+
+        const int64_t n_seqs = ubatch.n_seqs;
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+        struct ggml_tensor * state_copy = build_inp_s_copy();
+        struct ggml_tensor * state_mask = build_inp_s_mask();
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];
+
+            // (ab)using the KV cache to store the states
+            struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.k_l[il], state_copy, state_mask,
+                    hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
+            struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.v_l[il], state_copy, state_mask,
+                    hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
+
+            cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+            token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 1, n_seqs);
+
+            struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, cb, il);
+            struct ggml_tensor * x_prev = ggml_concat(
+                ctx0,
+                token_shift,
+                ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
+                1
+            );
+
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    wkv_states,
+                    ggml_view_1d(
+                        ctx0,
+                        kv_self.v_l[il],
+                        hparams.n_embd_v_s() * n_seqs,
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+                    )
+                )
+            );
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, hparams.n_head_kv()));
+            ggml_build_forward_expand(gf, ffn_inp);
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    wkv_states,
+                    ggml_view_1d(
+                        ctx0,
+                        kv_self.v_l[il],
+                        hparams.n_embd_v_s() * n_seqs,
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+                    )
+                )
+            );
+
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // ref: https://github.com/facebookresearch/chameleon
     // based on the original build_llama() function, changes:
     //   * qk-norm
@@ -10724,6 +10900,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_rwkv6();
             } break;
+        case LLM_ARCH_RWKV6QWEN2:
+            {
+                result = llm.build_rwkv6qwen2();
+            } break;
         case LLM_ARCH_CHAMELEON:
             {
                 result = llm.build_chameleon();
index 1e892f66365e0f7318320cac756603ce532233df..3834e0f84aa72b7c95dd9d0db47de7f61ceec6a2 100644 (file)
@@ -1659,17 +1659,46 @@ struct test_rwkv_wkv6 : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t n_tokens = n_seq_tokens * n_seqs;
-        ggml_tensor * r   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
-        ggml_tensor * k   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
-        ggml_tensor * v   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+        ggml_tensor * r   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
         ggml_tensor * tf  = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
-        ggml_tensor * td  = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+        ggml_tensor * td  = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
         ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
         ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s);
         return out;
     }
 };
 
+// GGML_OP_GATED_LINEAR_ATTN
+struct test_gla : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_size;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+    }
+
+    test_gla(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_tokens = n_seq_tokens * n_seqs;
+        ggml_tensor * q   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * g   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
+        ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, s, pow(head_size, -0.5));
+        return out;
+    }
+};
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -3626,6 +3655,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
 
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
+
     for (int i = 1; i < 9; ++i) {
         test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16,    GGML_TYPE_F32, 16,  i, 256, { 1,  1}, {1, 1}));
         test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0,   GGML_TYPE_F32, 16,  i, 256, { 1,  1}, {1, 1}));