]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
llama: add support for QRWKV6 model architecture (llama/11001)
authorMolly Sophia <redacted>
Fri, 10 Jan 2025 01:58:08 +0000 (09:58 +0800)
committerGeorgi Gerganov <redacted>
Tue, 14 Jan 2025 08:38:01 +0000 (10:38 +0200)
llama: add support for QRWKV6 model architecture (llama/11001)

* WIP: Add support for RWKV6Qwen2

Signed-off-by: Molly Sophia <redacted>
* RWKV: Some graph simplification

Signed-off-by: Molly Sophia <redacted>
* Add support for RWKV6Qwen2 with cpu and cuda GLA

Signed-off-by: Molly Sophia <redacted>
* RWKV6[QWEN2]: Concat lerp weights together to reduce cpu overhead

Signed-off-by: Molly Sophia <redacted>
* Fix some typos

Signed-off-by: Molly Sophia <redacted>
* code format changes

Signed-off-by: Molly Sophia <redacted>
* Fix wkv test & add gla test

Signed-off-by: Molly Sophia <redacted>
* Fix cuda warning

Signed-off-by: Molly Sophia <redacted>
* Update README.md

Signed-off-by: Molly Sophia <redacted>
* Update ggml/src/ggml-cuda/gla.cu

Co-authored-by: Georgi Gerganov <redacted>
* Fix fused lerp weights loading with RWKV6

Signed-off-by: Molly Sophia <redacted>
* better sanity check skipping for QRWKV6 in llama-quant

thanks @compilade

Signed-off-by: Molly Sophia <redacted>
Co-authored-by: compilade <redacted>
---------

Signed-off-by: Molly Sophia <redacted>
Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: compilade <redacted>
ggml/include/ggml.h
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/gla.cu [new file with mode: 0644]
ggml/src/ggml-cuda/gla.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/wkv6.cu
ggml/src/ggml-sycl/wkv6.cpp
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml.c

index 8630d92c5c6a496150fbf87225dc5a5b0f16a641..8f8cb9e1aa1401c536db42620d6809be5b55b7c8 100644 (file)
@@ -501,6 +501,7 @@ extern "C" {
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
+        GGML_OP_GATED_LINEAR_ATTN,
 
         GGML_OP_UNARY,
 
@@ -1859,6 +1860,15 @@ extern "C" {
             struct ggml_tensor  * td,
             struct ggml_tensor  * state);
 
+    GGML_API struct ggml_tensor * ggml_gated_linear_attn(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * state,
+            float scale);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
index b7fefb9ddfd894be76566e116f2a9cd8ede1047f..2966ff7682de2c001d372842e787158369907eb8 100644 (file)
@@ -11803,9 +11803,9 @@ static void ggml_compute_forward_add_rel_pos(
 static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[3];
+    const int64_t T = dst->src[1]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t HEADS = dst->src[1]->ne[1];
     const int64_t n_seqs = dst->src[5]->ne[1];
     const int64_t head_size = C / HEADS;
 
@@ -12000,6 +12000,197 @@ static void ggml_compute_forward_rwkv_wkv6(
     }
 }
 
+// ggml_compute_forward_gla
+
+static void ggml_compute_forward_gla_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[4]->ne[1];
+    const int64_t head_size = C / HEADS;
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * q = (float *) dst->src[2]->data;
+    float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // Same to C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define GLA_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define GLA_VECTOR_SIZE 4
+    #endif
+
+    #ifdef GLA_VECTOR_SIZE
+        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
+                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load x elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = prev_state * g + kv
+                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+                        // Update dst: dst += temp * q
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+                    }
+
+                    // Handle remaining elements, this will not be used.
+                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val + prev_state_val * g_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = prev_state_val * g_val + kv_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
+
+
+static void ggml_compute_forward_gla(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gla_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -12749,6 +12940,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            {
+                ggml_compute_forward_gla(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -13047,6 +13242,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
index 0b06be729864e6de01b8d97a5da858f5017a4b3f..8476ee1bca50c8c0cd8645f0238e7a4ef8f691d5 100644 (file)
@@ -37,6 +37,7 @@
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/gla.cuh"
 
 #include <algorithm>
 #include <array>
@@ -2167,6 +2168,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_cuda_op_gated_linear_attn(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
@@ -3011,6 +3015,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
diff --git a/ggml/src/ggml-cuda/gla.cu b/ggml/src/ggml-cuda/gla.cu
new file mode 100644 (file)
index 0000000..f7d615a
--- /dev/null
@@ -0,0 +1,93 @@
+#include "common.cuh"
+#include "gla.cuh"
+
+template<int HEAD_SIZE>
+static __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale,
+     const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = HEAD_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _td[head_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4 & k = (float4 &)(_k[j]);
+            const float4 & r = (float4 &)(_r[j]);
+            const float4 & td = (float4 &)(_td[j]);
+            float4 & s = (float4 &)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+
+            y += r.x * s.x;
+            y += r.y * s.y;
+            y += r.z * s.z;
+            y += r.w * s.w;
+        }
+        dst[t] = y * scale;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d  = (const float *)dst->src[0]->data;
+    const float * v_d  = (const float *)dst->src[1]->data;
+    const float * r_d  = (const float *)dst->src[2]->data;
+    const float * td_d = (const float *)dst->src[3]->data;
+    const float * s_d  = (const float *)dst->src[4]->data;
+
+    const int64_t B = dst->src[4]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float scale;
+    memcpy(&scale, (float*)dst->op_params, sizeof(float));
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == 64 || C / H == 128);
+
+
+    if (C / H == 64) {
+        gated_linear_attn_f32<64><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    } else {
+        gated_linear_attn_f32<128><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    }
+}
diff --git a/ggml/src/ggml-cuda/gla.cuh b/ggml/src/ggml-cuda/gla.cuh
new file mode 100644 (file)
index 0000000..2c82ad7
--- /dev/null
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 42578341a386b5ff5178ba465b5ee99f48d175f9..bbdafbee5818ba666f282c7180f33acdabb2df87 100644 (file)
@@ -73,9 +73,9 @@ void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const float * s_d  = (const float *)dst->src[5]->data;
 
     const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[3];
+    const int64_t T = dst->src[0]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[2];
+    const int64_t H = dst->src[0]->ne[1];
 
     float * dst_d = (float *)dst->data;
 
index 4fed18c2aafb14ca6c3ba5d17245ea003af30a08..b54c20964ed5dfc0bb820a8af599bb2c2d06953a 100644 (file)
@@ -109,9 +109,9 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
     float* dst_d = (float*)dst->data;
 
     const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[3];
+    const int64_t T = dst->src[0]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[2];
+    const int64_t H = dst->src[0]->ne[1];
 
     GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
     GGML_ASSERT(C % H == 0);
index 0774524242a6b3cfa4b00c271092b0ca921226e1..1b917468239ddd49a385e3766444b87d0ffa5451 100644 (file)
@@ -5633,9 +5633,9 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
 }
 
 static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
-    const size_t seq_length = dst->src[0]->ne[3];
+    const size_t seq_length = dst->src[0]->ne[2];
     const size_t n_embed = dst->ne[0];
-    const size_t n_heads = dst->src[0]->ne[2];
+    const size_t n_heads = dst->src[0]->ne[1];
     const size_t n_seqs = dst->src[5]->ne[1];
 
     ggml_vk_op_f32_rwkv6(
index 90abc6ad45233aab20ada03b8e7a9c6736e506c1..da5b817e1563785bbc9ffa48e559019c5d8b5c63 100644 (file)
@@ -968,6 +968,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GET_REL_POS",
     "ADD_REL_POS",
     "RWKV_WKV6",
+    "GATED_LINEAR_ATTN",
 
     "UNARY",
 
@@ -987,7 +988,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1064,6 +1065,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "get_rel_pos(x)",
     "add_rel_pos(x)",
     "rwkv_wkv6(k, v, r, tf, td, s)",
+    "gated_linear_attn(k, v, q, gate, s)",
 
     "unary(x)",
 
@@ -1083,7 +1085,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4629,15 +4631,13 @@ struct ggml_tensor * ggml_rwkv_wkv6(
     GGML_ASSERT(ggml_is_contiguous(state));
 
     const int64_t S = k->ne[0];
-    const int64_t H = k->ne[2];
-    const int64_t n_tokens = k->ne[3];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
     const int64_t n_seqs = state->ne[1];
     {
-        GGML_ASSERT(k->ne[1] == 1);
-        GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
-        GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
-        // TODO: RWKV v4 and v5
-        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
+        GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
         GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
     }
 
@@ -4656,6 +4656,49 @@ struct ggml_tensor * ggml_rwkv_wkv6(
     return result;
 }
 
+// ggml_gated_linear_attn
+
+struct ggml_tensor * ggml_gated_linear_attn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * state,
+        float scale) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
+        GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_set_op_params_f32(result, 0, scale);
+
+    result->op     = GGML_OP_GATED_LINEAR_ATTN;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = q;
+    result->src[3] = g;
+    result->src[4] = state;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(