- [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat)
- [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a)
- [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM)
+- [x] [QRWKV-6](https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1)
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
#### Multimodal
gguf.MODEL_TENSOR.TIME_MIX_W2,
gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1,
gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2,
+ gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED,
gguf.MODEL_TENSOR.POSNET_NORM1,
gguf.MODEL_TENSOR.POSNET_NORM2,
)
# required by llama.cpp, unused
self.gguf_writer.add_head_count(0)
+ lerp_weights: dict[int, dict[str, Tensor]] = {}
+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
new_name = self.map_tensor_name(name)
if new_name.endswith("time_mix_decay.weight") or "lerp" in new_name:
data_torch = data_torch.squeeze()
- rescale_every_n_layers = self.hparams["rescale_every"]
- if rescale_every_n_layers > 0:
- if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
- data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
+ try:
+ rescale_every_n_layers = self.hparams["rescale_every"]
+ if rescale_every_n_layers > 0:
+ if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
+ data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
+ except KeyError:
+ pass
+
+        # concatenate time_mix_lerp weights to reduce CPU overhead;
+        # this also reduces the number of tensors in the model
+ if bid is not None and "time_mix_lerp" in new_name and "time_mix_lerp_x" not in new_name:
+ try:
+ self.lerp_weights[bid][new_name] = data_torch
+ except KeyError:
+ self.lerp_weights[bid] = {new_name: data_torch}
+ if all(f"blk.{bid}.time_mix_lerp_{i}.weight" in self.lerp_weights[bid].keys() for i in ["w", "k", "v", "r", "g"]):
+ new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
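+                # stack the five (n_embd,) lerp vectors into a single fused tensor (loaded as {n_embd, 1, 1, 5} on the llama.cpp side)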
+ data = torch.stack([self.lerp_weights[bid][f"blk.{bid}.time_mix_lerp_{i}.weight"].unsqueeze(0) for i in ["w", "k", "v", "r", "g"]], dim=0).unsqueeze(1)
+ yield (new_name, data)
+ return
yield (new_name, data_torch)
+
+
+@Model.register("RWKV6Qwen2ForCausalLM")
+class RWKV6Qwen2Model(Rwkv6Model):
+ model_arch = gguf.MODEL_ARCH.RWKV6QWEN2
+
+ def set_vocab(self):
+ try:
+ self._set_vocab_sentencepiece()
+ except FileNotFoundError:
+ self._set_vocab_gpt2()
+
+ def set_gguf_parameters(self):
+ block_count = self.hparams["num_hidden_layers"]
+ num_attention_heads = self.hparams["num_attention_heads"]
+ num_key_value_heads = self.hparams["num_key_value_heads"]
+ hidden_size = self.hparams["hidden_size"]
+ head_size = hidden_size // num_attention_heads
+ rms_norm_eps = self.hparams["rms_norm_eps"]
+ intermediate_size = self.hparams["intermediate_size"]
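+        # low-rank dims for the data-dependent token shift and decay, sized by hidden_size to mirror the RWKV6 converter defaults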
+ time_mix_extra_dim = 64 if hidden_size >= 4096 else 32
+ time_decay_extra_dim = 128 if hidden_size >= 4096 else 64
+
+ # RWKV isn't context limited
+ self.gguf_writer.add_context_length(1048576)
+ self.gguf_writer.add_embedding_length(hidden_size)
+ self.gguf_writer.add_block_count(block_count)
+ self.gguf_writer.add_wkv_head_size(head_size)
+ self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
+ self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
+ self.gguf_writer.add_feed_forward_length(intermediate_size)
+ self.gguf_writer.add_file_type(self.ftype)
+
+ # special parameters for time_mixing in RWKV6QWEN2
+ self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+ self.gguf_writer.add_token_shift_count(1)
+        # RWKV6QWEN2 uses grouped key/value heads (GQA)
+ self.gguf_writer.add_head_count_kv(num_key_value_heads)
+
+ # required by llama.cpp, unused
+ self.gguf_writer.add_head_count(0)
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ for new_name, data in super().modify_tensors(data_torch, name, bid):
+ if "time_mix_w1" in new_name or "time_mix_w2" in new_name:
+ data = data.view(5, -1, data.shape[-1])
+                # rwkv6qwen2 stores these projections in rkvwg order instead of the original wkvrg;
+                # permute them here to avoid changes elsewhere in the code
+ data = torch.stack([data[3], data[1], data[2], data[0], data[4]], dim=0).view(-1, data.shape[-1])
+ if "w2" in new_name:
+ data = data.view(5, -1, data.shape[-1])
+ yield (new_name, data)
+ continue
+ yield (new_name, data)
+
+
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA
GGML_OP_GET_REL_POS,
GGML_OP_ADD_REL_POS,
GGML_OP_RWKV_WKV6,
+ GGML_OP_GATED_LINEAR_ATTN,
GGML_OP_UNARY,
struct ggml_tensor * td,
struct ggml_tensor * state);
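+
+    // gated linear attention (used by RWKV6QWEN2):
+    // k, v, q and g have shape {head_size, n_head, n_tokens};
+    // state holds n_seqs states of head_size*head_size*n_head values each;
+    // the result concatenates the attention output with the updated state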
+ GGML_API struct ggml_tensor * ggml_gated_linear_attn(
+ struct ggml_context * ctx,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * q,
+ struct ggml_tensor * g,
+ struct ggml_tensor * state,
+ float scale);
+
// custom operators
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
static void ggml_compute_forward_rwkv_wkv6_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
- const int64_t T = dst->src[1]->ne[3];
+ const int64_t T = dst->src[1]->ne[2];
const int64_t C = dst->ne[0];
- const int64_t HEADS = dst->src[1]->ne[2];
+ const int64_t HEADS = dst->src[1]->ne[1];
const int64_t n_seqs = dst->src[5]->ne[1];
const int64_t head_size = C / HEADS;
}
}
+// ggml_compute_forward_gla
+
+static void ggml_compute_forward_gla_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ const int64_t T = dst->src[1]->ne[2];
+ const int64_t C = dst->ne[0];
+ const int64_t HEADS = dst->src[1]->ne[1];
+ const int64_t n_seqs = dst->src[4]->ne[1];
+ const int64_t head_size = C / HEADS;
+ const float scale = ggml_get_op_params_f32(dst, 0);
+
+ float * dst_data = (float *) dst->data;
+ float * state = ((float *) dst->data) + C * T;
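+    // dst packs the T * C output values first, followed by the updated per-sequence states (hence the C * T offset)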
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ if (ith >= HEADS) {
+ return;
+ }
+
+ const int h_start = (HEADS * ith) / nth;
+ const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+ (HEADS * (ith + 1)) / nth : HEADS;
+
+ float * k = (float *) dst->src[0]->data;
+ float * v = (float *) dst->src[1]->data;
+ float * q = (float *) dst->src[2]->data;
+ float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // same as C
+
+ size_t h_stride = C / HEADS;
+ GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+ size_t h_stride_2d = head_size * head_size;
+
+ if (ith == 0) {
+ memset(dst_data, 0, T * C * sizeof(float));
+ }
+ ggml_barrier(params->threadpool);
+
+
+ #if defined(__AVX__) && !defined(__AVX512F__)
+ #define GGML_F32X GGML_F32x8
+ #define GGML_F32X_SET1 GGML_F32x8_SET1
+ #define GGML_F32X_LOAD GGML_F32x8_LOAD
+ #define GGML_F32X_STORE GGML_F32x8_STORE
+ #define GGML_F32X_MUL GGML_F32x8_MUL
+ #define GGML_F32X_FMA GGML_F32x8_FMA
+ #define GLA_VECTOR_SIZE 8
+ #elif defined(__AVX512F__)
+ #define GGML_F32X GGML_F32x16
+ #define GGML_F32X_SET1 GGML_F32x16_SET1
+ #define GGML_F32X_LOAD GGML_F32x16_LOAD
+ #define GGML_F32X_STORE GGML_F32x16_STORE
+ #define GGML_F32X_MUL GGML_F32x16_MUL
+ #define GGML_F32X_FMA GGML_F32x16_FMA
+ #define GLA_VECTOR_SIZE 16
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
+ #define GGML_F32X GGML_F32x4
+ #define GGML_F32X_SET1 GGML_F32x4_SET1
+ #define GGML_F32X_LOAD GGML_F32x4_LOAD
+ #define GGML_F32X_STORE GGML_F32x4_STORE
+ #define GGML_F32X_MUL GGML_F32x4_MUL
+ #define GGML_F32X_FMA GGML_F32x4_FMA
+ #define GLA_VECTOR_SIZE 4
+ #endif
+
+ #ifdef GLA_VECTOR_SIZE
+ const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+ for (int64_t t = 0; t < T; t++) {
+ size_t t_offset = t * t_stride;
+ size_t state_offset = head_size * C * (t / (T / n_seqs));
+ float * state_cur = state + state_offset;
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
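+        // the first token of each sequence reads the incoming state (src[4]); later tokens chain through state_cur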
+
+ for (int64_t h = h_start; h < h_end; h++) {
+ size_t h_offset = h * h_stride;
+ size_t t_h_offset = t_offset + h_offset;
+ size_t h_2d_offset = h * h_stride_2d;
+
+ for (int64_t i = 0; i < head_size; i++) {
+ size_t t_h_i_offset = t_h_offset + i;
+ size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+ float k_val = k[t_h_i_offset];
+ float q_val = q[t_h_i_offset] * scale;
+ float g_val = g[t_h_i_offset];
+
+ // Broadcast scalar values to vectors
+ GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+ GGML_F32X q_vec = GGML_F32X_SET1(q_val);
+ GGML_F32X g_vec = GGML_F32X_SET1(g_val);
+
+ for (int64_t j = 0; j < vec_count; j++) {
+ size_t base_j = j * GLA_VECTOR_SIZE;
+ size_t t_h_j_offset = t_h_offset + base_j;
+ size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                    // Load GLA_VECTOR_SIZE elements at once
+ GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+ GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+ GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+ // Compute kv = v * k
+ GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+ // Compute temp = prev_state * g + kv
+ GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+ // Update dst: dst += temp * q
+ dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+ GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+ // Update state
+ GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+ }
+
+            // Handle remaining elements; head_size is a multiple of the vector width, so this tail loop is normally never taken.
+ for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+ size_t t_h_j_offset = t_h_offset + j;
+ size_t h_2d_i_j_offset = h_2d_i_offset + j;
+ float v_val = v[t_h_j_offset];
+ float kv_val = v_val * k_val;
+ float prev_state_val = state_prev[h_2d_i_j_offset];
+ float temp_val = kv_val + prev_state_val * g_val;
+ dst_data[t_h_j_offset] += temp_val * q_val;
+ state_cur[h_2d_i_j_offset] = temp_val;
+ }
+ }
+ }
+ }
+
+ #else
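+    // scalar fallback when no SIMD macros are defined above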
+ for (int64_t t = 0; t < T; t++) {
+ size_t t_offset = t * t_stride;
+ size_t state_offset = head_size * C * (t / (T / n_seqs));
+ float * state_cur = state + state_offset;
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+ for (int64_t h = h_start; h < h_end; h++) {
+ size_t h_offset = h * h_stride;
+ size_t t_h_offset = t_offset + h_offset;
+ size_t h_2d_offset = h * h_stride_2d;
+
+ for (int64_t i = 0; i < head_size; i++) {
+ size_t t_h_i_offset = t_h_offset + i;
+ size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+ float k_val = k[t_h_i_offset];
+ float q_val = q[t_h_i_offset] * scale;
+ float g_val = g[t_h_i_offset];
+
+ for (int64_t j = 0; j < head_size; j++) {
+ size_t t_h_j_offset = t_h_offset + j;
+ size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+ float v_val = v[t_h_j_offset];
+ float kv_val = v_val * k_val;
+ float prev_state_val = state_prev[h_2d_i_j_offset];
+ float temp_val = prev_state_val * g_val + kv_val;
+ dst_data[t_h_j_offset] += temp_val * q_val;
+ state_cur[h_2d_i_j_offset] = temp_val;
+ }
+ }
+ }
+ }
+ #endif
+}
+
+
+static void ggml_compute_forward_gla(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_gla_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_map_unary
static void ggml_compute_forward_map_unary_f32(
{
ggml_compute_forward_rwkv_wkv6(params, tensor);
} break;
+ case GGML_OP_GATED_LINEAR_ATTN:
+ {
+ ggml_compute_forward_gla(params, tensor);
+ } break;
case GGML_OP_MAP_UNARY:
{
ggml_unary_op_f32_t fun;
case GGML_OP_WIN_UNPART:
case GGML_OP_GET_REL_POS:
case GGML_OP_RWKV_WKV6:
+ case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1_F32:
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/gla.cuh"
#include <algorithm>
#include <array>
case GGML_OP_RWKV_WKV6:
ggml_cuda_op_rwkv_wkv6(ctx, dst);
break;
+ case GGML_OP_GATED_LINEAR_ATTN:
+ ggml_cuda_op_gated_linear_attn(ctx, dst);
+ break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
ggml_cuda_cross_entropy_loss_back(ctx, dst);
break;
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
case GGML_OP_RWKV_WKV6:
+ case GGML_OP_GATED_LINEAR_ATTN:
return true;
case GGML_OP_FLASH_ATTN_EXT: {
#ifndef FLASH_ATTN_AVAILABLE
--- /dev/null
+#include "common.cuh"
+#include "gla.cuh"
+
+template<int HEAD_SIZE>
+static __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale,
+ const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) {
+ const int tid = threadIdx.x;
+ const int bid = blockIdx.x;
+
+ const int head_size = HEAD_SIZE;
+ const int batch_i = bid / H;
+ const int head_i = bid % H;
+ const int state_size = C * head_size;
+ const int n_seq_tokens = T / B;
+
+ float state[head_size];
+ __shared__ float _k[head_size], _r[head_size], _td[head_size];
+
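+    // preload this thread's column (index tid) of the per-head state into registers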
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+ }
+
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
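+        // stage this token's k, r and decay values in shared memory so every thread can read the full head row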
+ __syncthreads();
+ _k[tid] = k[t];
+ _r[tid] = r[t];
+ _td[tid] = td[t];
+ __syncthreads();
+
+ const float _v = v[t];
+ float y = 0;
+ for (int j = 0; j < head_size; j += 4) {
+ const float4 & k = (float4 &)(_k[j]);
+ const float4 & r = (float4 &)(_r[j]);
+ const float4 & td = (float4 &)(_td[j]);
+ float4 & s = (float4 &)(state[j]);
+ float4 kv;
+
+ kv.x = k.x * _v;
+ kv.y = k.y * _v;
+ kv.z = k.z * _v;
+ kv.w = k.w * _v;
+
+ s.x = s.x * td.x + kv.x;
+ s.y = s.y * td.y + kv.y;
+ s.z = s.z * td.z + kv.z;
+ s.w = s.w * td.w + kv.w;
+
+ y += r.x * s.x;
+ y += r.y * s.y;
+ y += r.z * s.z;
+ y += r.w * s.w;
+ }
+ dst[t] = y * scale;
+ }
+
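+    // write the updated state back after the T * C output values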
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+ }
+}
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const float * k_d = (const float *)dst->src[0]->data;
+ const float * v_d = (const float *)dst->src[1]->data;
+ const float * r_d = (const float *)dst->src[2]->data;
+ const float * td_d = (const float *)dst->src[3]->data;
+ const float * s_d = (const float *)dst->src[4]->data;
+
+ const int64_t B = dst->src[4]->ne[1];
+ const int64_t T = dst->src[0]->ne[2];
+ const int64_t C = dst->ne[0];
+ const int64_t H = dst->src[0]->ne[1];
+
+ float scale;
+ memcpy(&scale, (float*)dst->op_params, sizeof(float));
+
+ float * dst_d = (float *)dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
+ GGML_ASSERT(C % H == 0);
+ GGML_ASSERT(C / H == 64 || C / H == 128);
+
+ if (C / H == 64) {
+ gated_linear_attn_f32<64><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+ } else {
+ gated_linear_attn_f32<128><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+ }
+}
--- /dev/null
+#include "common.cuh"
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
const float * s_d = (const float *)dst->src[5]->data;
const int64_t B = dst->src[5]->ne[1];
- const int64_t T = dst->src[0]->ne[3];
+ const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
- const int64_t H = dst->src[0]->ne[2];
+ const int64_t H = dst->src[0]->ne[1];
float * dst_d = (float *)dst->data;
float* dst_d = (float*)dst->data;
const int64_t B = dst->src[5]->ne[1];
- const int64_t T = dst->src[0]->ne[3];
+ const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
- const int64_t H = dst->src[0]->ne[2];
+ const int64_t H = dst->src[0]->ne[1];
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
}
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
- const size_t seq_length = dst->src[0]->ne[3];
+ const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
- const size_t n_heads = dst->src[0]->ne[2];
+ const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[5]->ne[1];
ggml_vk_op_f32_rwkv6(
"GET_REL_POS",
"ADD_REL_POS",
"RWKV_WKV6",
+ "GATED_LINEAR_ATTN",
"UNARY",
"OPT_STEP_ADAMW",
};
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
"get_rel_pos(x)",
"add_rel_pos(x)",
"rwkv_wkv6(k, v, r, tf, td, s)",
+ "gated_linear_attn(k, v, q, gate, s)",
"unary(x)",
"adamw(x)",
};
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
GGML_ASSERT(ggml_is_contiguous(state));
const int64_t S = k->ne[0];
- const int64_t H = k->ne[2];
- const int64_t n_tokens = k->ne[3];
+ const int64_t H = k->ne[1];
+ const int64_t n_tokens = k->ne[2];
const int64_t n_seqs = state->ne[1];
{
- GGML_ASSERT(k->ne[1] == 1);
- GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
- GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
- // TODO: RWKV v4 and v5
- GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+ GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+ GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
+ GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
}
return result;
}
+// ggml_gated_linear_attn
+
+struct ggml_tensor * ggml_gated_linear_attn(
+ struct ggml_context * ctx,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * q,
+ struct ggml_tensor * g,
+ struct ggml_tensor * state,
+ float scale) {
+ GGML_ASSERT(ggml_is_contiguous(k));
+ GGML_ASSERT(ggml_is_contiguous(v));
+ GGML_ASSERT(ggml_is_contiguous(q));
+ GGML_ASSERT(ggml_is_contiguous(g));
+ GGML_ASSERT(ggml_is_contiguous(state));
+
+ const int64_t S = k->ne[0];
+ const int64_t H = k->ne[1];
+ const int64_t n_tokens = k->ne[2];
+ const int64_t n_seqs = state->ne[1];
+ {
+ GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+ GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
+ GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
+ GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+ }
+
+ // concat output and new_state
+ const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ ggml_set_op_params_f32(result, 0, scale);
+
+ result->op = GGML_OP_GATED_LINEAR_ATTN;
+ result->src[0] = k;
+ result->src[1] = v;
+ result->src[2] = q;
+ result->src[3] = g;
+ result->src[4] = state;
+
+ return result;
+}
+
// ggml_unary
static struct ggml_tensor * ggml_unary_impl(
TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim"
RESIDUAL_SCALE = "{arch}.residual_scale"
EMBEDDING_SCALE = "{arch}.embedding_scale"
+ TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
GEMMA2 = auto()
STARCODER2 = auto()
RWKV6 = auto()
+ RWKV6QWEN2 = auto()
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
TIME_MIX_LERP_V = auto()
TIME_MIX_LERP_R = auto()
TIME_MIX_LERP_G = auto()
+ TIME_MIX_LERP_FUSED = auto()
TIME_MIX_LERP_W = auto()
TIME_MIX_FIRST = auto()
TIME_MIX_DECAY = auto()
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
+ MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
MODEL_TENSOR.TIME_MIX_LERP_R: "blk.{bid}.time_mix_lerp_r",
MODEL_TENSOR.TIME_MIX_LERP_G: "blk.{bid}.time_mix_lerp_g",
+ MODEL_TENSOR.TIME_MIX_LERP_FUSED: "blk.{bid}.time_mix_lerp_fused",
MODEL_TENSOR.TIME_MIX_LERP_W: "blk.{bid}.time_mix_lerp_w",
MODEL_TENSOR.TIME_MIX_FIRST: "blk.{bid}.time_mix_first",
MODEL_TENSOR.TIME_MIX_DECAY: "blk.{bid}.time_mix_decay",
MODEL_TENSOR.TIME_MIX_LERP_R,
MODEL_TENSOR.TIME_MIX_LERP_G,
MODEL_TENSOR.TIME_MIX_LERP_W,
+ MODEL_TENSOR.TIME_MIX_LERP_FUSED,
MODEL_TENSOR.TIME_MIX_FIRST,
MODEL_TENSOR.TIME_MIX_DECAY,
MODEL_TENSOR.TIME_MIX_DECAY_W1,
MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE,
MODEL_TENSOR.CHANNEL_MIX_VALUE,
],
+ MODEL_ARCH.RWKV6QWEN2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.TIME_MIX_W1,
+ MODEL_TENSOR.TIME_MIX_W2,
+ MODEL_TENSOR.TIME_MIX_LERP_X,
+ MODEL_TENSOR.TIME_MIX_LERP_K,
+ MODEL_TENSOR.TIME_MIX_LERP_V,
+ MODEL_TENSOR.TIME_MIX_LERP_R,
+ MODEL_TENSOR.TIME_MIX_LERP_G,
+ MODEL_TENSOR.TIME_MIX_LERP_W,
+ MODEL_TENSOR.TIME_MIX_LERP_FUSED,
+ MODEL_TENSOR.TIME_MIX_FIRST,
+ MODEL_TENSOR.TIME_MIX_DECAY,
+ MODEL_TENSOR.TIME_MIX_DECAY_W1,
+ MODEL_TENSOR.TIME_MIX_DECAY_W2,
+ MODEL_TENSOR.TIME_MIX_KEY,
+ MODEL_TENSOR.TIME_MIX_VALUE,
+ MODEL_TENSOR.TIME_MIX_RECEPTANCE,
+ MODEL_TENSOR.TIME_MIX_GATE,
+ MODEL_TENSOR.TIME_MIX_LN,
+ MODEL_TENSOR.TIME_MIX_OUTPUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
MODEL_ARCH.MAMBA: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
def add_wkv_head_size(self, size: int) -> None:
self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
+ def add_token_shift_count(self, count: int) -> None:
+ self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count)
+
def add_layer_norm_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
- "model.embed_tokens", # llama-hf nemotron olmoe olmo2
+ "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert
"language_model.embedding.word_embeddings", # persimmon
MODEL_TENSOR.TIME_MIX_W1: (
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6
+ "model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_W2: (
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6
+ "model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_X: (
"rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6
+ "model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_K: (
"rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6
+ "model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_V: (
"rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6
+ "model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_R: (
"rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6
+ "model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_G: (
"rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6
+ "model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LERP_W: (
"rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6
+ "model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_FIRST: (
MODEL_TENSOR.TIME_MIX_DECAY: (
"rwkv.blocks.{bid}.attention.time_decay", # rwkv v6
+ "model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_DECAY_W1: (
"rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6
+ "model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_DECAY_W2: (
"rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6
+ "model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_KEY: (
- "rwkv.blocks.{bid}.attention.key", # rwkv
+ "rwkv.blocks.{bid}.attention.key", # rwkv
+ "model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_VALUE: (
- "rwkv.blocks.{bid}.attention.value", # rwkv
+ "rwkv.blocks.{bid}.attention.value", # rwkv
+ "model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
"rwkv.blocks.{bid}.attention.receptance", # rwkv
+ "model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_GATE: (
- "rwkv.blocks.{bid}.attention.gate", # rwkv
+ "rwkv.blocks.{bid}.attention.gate", # rwkv
+ "model.layers.{bid}.self_attn.gate", # rwkv6qwen2
),
MODEL_TENSOR.TIME_MIX_LN: (
),
MODEL_TENSOR.TIME_MIX_OUTPUT: (
- "rwkv.blocks.{bid}.attention.output", # rwkv
+ "rwkv.blocks.{bid}.attention.output", # rwkv
+ "model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2
),
MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
{ LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_RWKV6, "rwkv6" },
+ { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
{ LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" },
{ LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
+ { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
{ LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" },
{ LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" },
{ LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" },
+ { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
{ LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" },
{ LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" },
{ LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" },
{ LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" },
},
},
+ {
+ LLM_ARCH_RWKV6QWEN2,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
+ { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
+ { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" },
+ { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
+ { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" },
+ { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" },
+ { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" },
+ { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" },
+ { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
+ { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
+ { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
+ { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" },
+ { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ },
+ },
{
LLM_ARCH_GRANITE,
{
{LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+ {LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
LLM_ARCH_NEMOTRON,
LLM_ARCH_EXAONE,
LLM_ARCH_RWKV6,
+ LLM_ARCH_RWKV6QWEN2,
LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_CHAMELEON,
LLM_KV_TIME_DECAY_EXTRA_DIM,
LLM_KV_RESIDUAL_SCALE,
LLM_KV_EMBEDDING_SCALE,
+ LLM_KV_TOKEN_SHIFT_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
LLM_TENSOR_TIME_MIX_LERP_V,
LLM_TENSOR_TIME_MIX_LERP_R,
LLM_TENSOR_TIME_MIX_LERP_G,
+ LLM_TENSOR_TIME_MIX_LERP_FUSED,
LLM_TENSOR_TIME_MIX_FIRST,
LLM_TENSOR_TIME_MIX_DECAY,
LLM_TENSOR_TIME_MIX_DECAY_W1,
uint32_t llama_hparams::n_embd_k_s() const {
if (wkv_head_size != 0) {
// for RWKV models
- return 2 * n_embd;
+ return token_shift_count * n_embd;
}
// TODO: maybe support other convolution strides than 1
uint32_t time_mix_extra_dim = 0;
uint32_t time_decay_extra_dim = 0;
uint32_t wkv_head_size = 0;
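+    // number of token-shift states per layer (2 for RWKV6: time mix + channel mix; 1 for RWKV6QWEN2)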
+ uint32_t token_shift_count = 2;
float rope_attn_factor = 1.0f;
float rope_freq_base_train;
}
} break;
case LLM_ARCH_RWKV6:
+ case LLM_ARCH_RWKV6QWEN2:
{
- ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false);
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false);
ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false);
+ ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false);
switch (hparams.n_layer) {
case 24: model.type = e_model::MODEL_1_6B; break;
default: model.type = e_model::MODEL_UNKNOWN;
} break;
case 61: model.type = e_model::MODEL_14B; break;
+ case 64: model.type = e_model::MODEL_32B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_T5ENCODER:
case LLM_ARCH_JAIS:
case LLM_ARCH_RWKV6:
+ case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_WAVTOKENIZER_DEC:
return LLAMA_ROPE_TYPE_NONE;
switch (model->arch) {
case LLM_ARCH_MAMBA: return true;
case LLM_ARCH_RWKV6: return true;
+ case LLM_ARCH_RWKV6QWEN2: return true;
default: return false;
}
}
struct ggml_tensor * time_mix_lerp_v = nullptr;
struct ggml_tensor * time_mix_lerp_r = nullptr;
struct ggml_tensor * time_mix_lerp_g = nullptr;
-
- struct ggml_tensor * time_mix_first = nullptr;
- struct ggml_tensor * time_mix_decay = nullptr;
- struct ggml_tensor * time_mix_decay_w1 = nullptr;
- struct ggml_tensor * time_mix_decay_w2 = nullptr;
- struct ggml_tensor * time_mix_key = nullptr;
- struct ggml_tensor * time_mix_value = nullptr;
- struct ggml_tensor * time_mix_receptance = nullptr;
- struct ggml_tensor * time_mix_gate = nullptr;
+ struct ggml_tensor * time_mix_lerp_fused = nullptr;
+
+ struct ggml_tensor * time_mix_first = nullptr;
+ struct ggml_tensor * time_mix_decay = nullptr;
+ struct ggml_tensor * time_mix_decay_w1 = nullptr;
+ struct ggml_tensor * time_mix_decay_w2 = nullptr;
+ struct ggml_tensor * time_mix_key = nullptr;
+ struct ggml_tensor * time_mix_key_b = nullptr;
+ struct ggml_tensor * time_mix_value = nullptr;
+ struct ggml_tensor * time_mix_value_b = nullptr;
+ struct ggml_tensor * time_mix_receptance = nullptr;
+ struct ggml_tensor * time_mix_receptance_b = nullptr;
+ struct ggml_tensor * time_mix_gate = nullptr;
struct ggml_tensor * time_mix_ln = nullptr;
struct ggml_tensor * time_mix_ln_b = nullptr;
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
- // sanity checks
+ // sanity checks for models that have attention layers
+ if (qs.n_attention_wv != 0)
{
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
// attention layers have a non-zero number of kv heads
quantize &= name.find("time_mix_w2.weight") == std::string::npos;
quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
+ quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
// do not quantize relative position bias (T5)
quantize &= name.find("attn_rel_b.weight") == std::string::npos;
const int64_t H = 123;
const int64_t n_tokens = 123;
const int64_t n_seqs = 123;
- ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens);
- ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
- ggml_tensor * r = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
+ ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
+ ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
+ ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
ggml_tensor * tf = w;
- ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
+ ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
} break;
layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
- layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, 0);
- layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
- layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, 0);
- layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
- layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, 0);
+ layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL));
layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0);
layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
}
} break;
+ case LLM_ARCH_RWKV6QWEN2:
+ {
+ model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+ const int time_mix_extra_dim = hparams.time_mix_extra_dim;
+ const int time_decay_extra_dim = hparams.time_decay_extra_dim;
+ const int head_size = hparams.wkv_head_size;
+ const int attn_hidden_size = n_embd;
+ const int n_head_kv = hparams.n_head_kv();
+ int attn_key_value_size;
+ if (n_head_kv == 0 || attn_hidden_size / head_size == n_head_kv) {
+ attn_key_value_size = attn_hidden_size;
+ } else {
+ attn_key_value_size = n_head_kv * head_size;
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0);
+ layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
+
+ layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
+ layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0);
+
+ layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
+ layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
+ layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
+ layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, attn_key_value_size}, 0);
+ layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, attn_key_value_size}, 0);
+ layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
+ layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
+ // optional bias tensors
+ layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+ layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ }
+ } break;
case LLM_ARCH_CHAMELEON:
{
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
const struct llama_layer * layer,
struct ggml_tensor * cur,
struct ggml_tensor * x_prev,
- struct ggml_tensor ** wkv_state) {
+ struct ggml_tensor ** wkv_state,
+ size_t wkv_head_size,
+ size_t head_count_kv) {
size_t n_embd = cur->ne[0];
size_t n_seq_tokens = cur->ne[1];
size_t n_seqs = cur->ne[2];
- size_t head_size = layer->time_mix_first->ne[0];
- size_t head_count = layer->time_mix_first->ne[1];
+ size_t head_size = wkv_head_size;
+ size_t head_count = n_embd / head_size;
size_t n_tokens = n_seqs * n_seq_tokens;
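+    // QRWKV (RWKV6QWEN2) has no time_mix_first tensor; its absence selects gated linear attention below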
+ bool is_qrwkv = layer->time_mix_first == nullptr;
+
struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
sx = ggml_reshape_2d(ctx, sx, n_embd, n_tokens);
xxx
);
- struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
- struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
- struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
- struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
- struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
-
- struct ggml_tensor * xw = ggml_add(
- ctx,
- ggml_mul(
- ctx,
- ggml_add(ctx, mw, layer->time_mix_lerp_w),
- sx
- ),
- cur
- );
+ struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
+ if (layer->time_mix_lerp_fused) {
+        // fusing these weights gives a modest performance improvement
+ sx = ggml_reshape_3d(ctx, sx, n_embd, 1, n_tokens);
+ cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
+ xxx = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xxx, layer->time_mix_lerp_fused), sx), cur);
+ xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+ xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+ xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+ xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+ xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+ } else {
+        // for backward compatibility with models that still ship separate lerp tensors
+ xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+ xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+ xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+ xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+ xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
- struct ggml_tensor * xk = ggml_add(
- ctx,
- ggml_mul(
- ctx,
- ggml_add(ctx, mk, layer->time_mix_lerp_k),
- sx
- ),
- cur
- );
+ xw = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xw, layer->time_mix_lerp_w), sx), cur);
+ xk = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xk, layer->time_mix_lerp_k), sx), cur);
+ xv = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xv, layer->time_mix_lerp_v), sx), cur);
+ xr = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xr, layer->time_mix_lerp_r), sx), cur);
+ xg = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xg, layer->time_mix_lerp_g), sx), cur);
+ }
- struct ggml_tensor * xv = ggml_add(
- ctx,
- ggml_mul(
- ctx,
- ggml_add(ctx, mv, layer->time_mix_lerp_v),
- sx
- ),
- cur
- );
+ struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
+ struct ggml_tensor * k = llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk);
+ struct ggml_tensor * v = llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv);
+ if (layer->time_mix_receptance_b) {
+ r = ggml_add(ctx, r, layer->time_mix_receptance_b);
+ }
+ if (layer->time_mix_key_b) {
+ k = ggml_add(ctx, k, layer->time_mix_key_b);
+ }
+ if (layer->time_mix_value_b) {
+ v = ggml_add(ctx, v, layer->time_mix_value_b);
+ }
- struct ggml_tensor * xr = ggml_add(
- ctx,
- ggml_mul(
- ctx,
- ggml_add(ctx, mr, layer->time_mix_lerp_r),
- sx
- ),
- cur
- );
+ struct ggml_tensor * g = llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg);
+ if (is_qrwkv) {
+ g = ggml_sigmoid(ctx, g);
+ } else {
+ g = ggml_silu(ctx, g);
+ }
- struct ggml_tensor * xg = ggml_add(
- ctx,
- ggml_mul(
- ctx,
- ggml_add(ctx, mg, layer->time_mix_lerp_g),
- sx
- ),
- cur
- );
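+    // with grouped KV heads (GQA), repeat k and v so every query head gets its own copy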
+ if (head_count_kv != head_count) {
+ GGML_ASSERT(head_count % head_count_kv == 0);
+ k = ggml_reshape_4d(ctx, k, head_size, 1, head_count_kv, n_tokens);
+ v = ggml_reshape_4d(ctx, v, head_size, 1, head_count_kv, n_tokens);
+ struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_size, head_count / head_count_kv, head_count_kv, n_tokens);
+ k = ggml_repeat(ctx, k, tmp);
+ v = ggml_repeat(ctx, v, tmp);
+ }
- struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens);
- struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens);
- struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens);
- struct ggml_tensor * g = ggml_silu(
- ctx,
- llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
- );
+ k = ggml_reshape_3d(ctx, k, head_size, head_count, n_tokens);
+ v = ggml_reshape_3d(ctx, v, head_size, head_count, n_tokens);
+ r = ggml_reshape_3d(ctx, r, head_size, head_count, n_tokens);
struct ggml_tensor * w = ggml_mul_mat(
ctx,
)
);
- w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embd));
+ w = ggml_add(ctx, w, layer->time_mix_decay);
w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
- w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
+ w = ggml_reshape_3d(ctx, w, head_size, head_count, n_tokens);
- k = ggml_transpose(ctx, k);
- v = ggml_transpose(ctx, v);
- r = ggml_transpose(ctx, r);
+ if (is_qrwkv) {
+ // k = k * (1 - w)
+ k = ggml_sub(ctx, k, ggml_mul(ctx, k, w));
+ }
- struct ggml_tensor * wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
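+    // QRWKV layers (no time_mix_first) use gated linear attention with 1/sqrt(head_size) scaling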
+ struct ggml_tensor * wkv_output;
+ if (!layer->time_mix_first) {
+ wkv_output = ggml_gated_linear_attn(ctx, k, v, r, w, *wkv_state, pow(head_size, -0.5f));
+ } else {
+ wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+ }
cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
*wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
- // group norm with head_count groups
- cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
- cur = ggml_norm(ctx, cur, 64e-5f);
+ if (!is_qrwkv) {
+ // group norm with head_count groups
+ cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
+ cur = ggml_norm(ctx, cur, 64e-5f);
- // Convert back to regular vectors.
- cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
- cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+ // Convert back to regular vectors.
+ cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+ cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+ } else {
+ cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+ }
cur = ggml_mul(ctx, cur, g);
cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
1
);
- cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
+ cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, n_embd / hparams.wkv_head_size));
ggml_build_forward_expand(gf, cur);
ggml_build_forward_expand(
gf,
return gf;
}
+ // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
+ ggml_cgraph * build_rwkv6qwen2() {
+ ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
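+        // with token_shift_count == 1, the per-layer k-cache row is exactly n_embd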
+ GGML_ASSERT(n_embd == hparams.n_embd_k_s());
+
+ const int64_t n_seqs = ubatch.n_seqs;
+ const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+ const int64_t n_tokens = ubatch.n_tokens;
+ GGML_ASSERT(n_seqs != 0);
+ GGML_ASSERT(ubatch.equal_seqs);
+ GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+ struct ggml_tensor * state_copy = build_inp_s_copy();
+ struct ggml_tensor * state_mask = build_inp_s_mask();
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+ for (int il = 0; il < n_layer; ++il) {
+ const llama_layer * layer = &model.layers[il];
+
+ // (ab)using the KV cache to store the states
+ struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
+ gf, kv_self.k_l[il], state_copy, state_mask,
+ hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
+ struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
+ gf, kv_self.v_l[il], state_copy, state_mask,
+ hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
+
+ cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+ token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 1, n_seqs);
+
+ struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, cb, il);
+ struct ggml_tensor * x_prev = ggml_concat(
+ ctx0,
+ token_shift,
+ ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
+ 1
+ );
+
+ ggml_build_forward_expand(
+ gf,
+ ggml_cpy(
+ ctx0,
+ wkv_states,
+ ggml_view_1d(
+ ctx0,
+ kv_self.v_l[il],
+ hparams.n_embd_v_s() * n_seqs,
+ hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+ )
+ )
+ );
+
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, hparams.n_head_kv()));
+ ggml_build_forward_expand(gf, ffn_inp);
+ ggml_build_forward_expand(
+ gf,
+ ggml_cpy(
+ ctx0,
+ wkv_states,
+ ggml_view_1d(
+ ctx0,
+ kv_self.v_l[il],
+ hparams.n_embd_v_s() * n_seqs,
+ hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+ )
+ )
+ );
+
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward network
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = llm_build_ffn(ctx0, lctx, cur,
+ model.layers[il].ffn_up, NULL, NULL,
+ model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+ cb(cur, "ffn_out", il);
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cur = lctx.cvec.apply_to(ctx0, cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+
+ cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
+
// ref: https://github.com/facebookresearch/chameleon
// based on the original build_llama() function, changes:
// * qk-norm
{
result = llm.build_rwkv6();
} break;
+ case LLM_ARCH_RWKV6QWEN2:
+ {
+ result = llm.build_rwkv6qwen2();
+ } break;
case LLM_ARCH_CHAMELEON:
{
result = llm.build_chameleon();
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t n_tokens = n_seq_tokens * n_seqs;
- ggml_tensor * r = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
- ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
- ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+ ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+ ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+ ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
- ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+ ggml_tensor * td = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s);
return out;
}
};
+// GGML_OP_GATED_LINEAR_ATTN
+struct test_gla : public test_case {
+ const ggml_type type;
+
+ const int64_t head_count;
+ const int64_t head_size;
+ const int64_t n_seq_tokens;
+ const int64_t n_seqs;
+
+ std::string vars() override {
+ return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+ }
+
+ test_gla(ggml_type type = GGML_TYPE_F32,
+ int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+ : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ const int64_t n_tokens = n_seq_tokens * n_seqs;
+ ggml_tensor * q = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+ ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+ ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+ ggml_tensor * g = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+ ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
+ ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, s, pow(head_size, -0.5));
+ return out;
+ }
+};
+
// GGML_OP_MUL_MAT
struct test_mul_mat : public test_case {
const ggml_type type_a;
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
+ test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
+ test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
+ test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
+ test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
+
for (int i = 1; i < 9; ++i) {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));