from __future__ import annotations
+import ast
import logging
import argparse
import contextlib
gguf.MODEL_TENSOR.POS_EMBD,
gguf.MODEL_TENSOR.TOKEN_TYPES,
gguf.MODEL_TENSOR.SSM_CONV1D,
+ gguf.MODEL_TENSOR.TIME_MIX_FIRST,
+ gguf.MODEL_TENSOR.TIME_MIX_W1,
+ gguf.MODEL_TENSOR.TIME_MIX_W2,
)
)
- or not name.endswith(".weight")
+ or not new_name.endswith(".weight")
):
data_qtype = gguf.GGMLQuantizationType.F32
model_arch = gguf.MODEL_ARCH.STARCODER2
+@Model.register("Rwkv6ForCausalLM")
+class Rwkv6Model(Model):
+ model_arch = gguf.MODEL_ARCH.RWKV6
+
+ def set_vocab(self):
+ assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
+ vocab_size = self.hparams.get("vocab_size", 65536)
+
+ tokens: list[bytes] = ['<s>'.encode("utf-8")]
+ toktypes: list[int] = [gguf.TokenType.CONTROL]
+
+ with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
+ lines = f.readlines()
+ for line in lines:
+ parts = line.split(' ')
+ assert len(parts) >= 3
+ token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
+ token = token.encode("utf-8") if isinstance(token, str) else token
+ assert isinstance(token, bytes)
+ assert len(token) == token_len
+ token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff"
+ tokens.append(token_text.encode("utf-8"))
+ toktypes.append(gguf.TokenType.NORMAL)
+ remainder = vocab_size - len(tokens)
+ assert remainder >= 0
+ for i in range(len(tokens), vocab_size):
+ tokens.append(f"[PAD{i}]".encode("utf-8"))
+ toktypes.append(gguf.TokenType.UNUSED)
+
+ self.gguf_writer.add_tokenizer_model("rwkv")
+ self.gguf_writer.add_token_list(tokens)
+ self.gguf_writer.add_token_types(toktypes)
+
+ def set_gguf_parameters(self):
+ block_count = self.hparams["num_hidden_layers"]
+ head_size = self.hparams["head_size"]
+ hidden_size = self.hparams["hidden_size"]
+ layer_norm_eps = self.hparams["layer_norm_epsilon"]
+ rescale_every_n_layers = self.hparams["rescale_every"]
+ intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else int((hidden_size * 3.5) // 32 * 32)
+ time_mix_extra_dim = 64 if hidden_size == 4096 else 32
+ time_decay_extra_dim = 128 if hidden_size == 4096 else 64
+
+ # RWKV isn't context limited
+ self.gguf_writer.add_context_length(1048576)
+ self.gguf_writer.add_embedding_length(hidden_size)
+ self.gguf_writer.add_block_count(block_count)
+ self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
+ self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
+ self.gguf_writer.add_wkv_head_size(head_size)
+ self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
+ self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
+ self.gguf_writer.add_feed_forward_length(intermediate_size)
+ self.gguf_writer.add_file_type(self.ftype)
+
+ # required by llama.cpp, unused
+ self.gguf_writer.add_head_count(0)
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ new_name = self.map_tensor_name(name)
+
+ if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
+ new_name += ".weight"
+
+ if new_name.endswith("time_mix_w1.weight") or new_name.endswith("time_mix_decay_w1.weight") or new_name.endswith("time_mix_decay_w2.weight"):
+ data_torch = data_torch.transpose(0, 1)
+
+ if new_name.endswith("time_mix_w2.weight"):
+ data_torch = data_torch.permute(0, 2, 1)
+
+ rescale_every_n_layers = self.hparams["rescale_every"]
+ if rescale_every_n_layers > 0:
+ if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
+ data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
+
+ yield (new_name, data_torch)
+
+
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA
GGML_OP_WIN_UNPART,
GGML_OP_GET_REL_POS,
GGML_OP_ADD_REL_POS,
+ GGML_OP_RWKV_WKV,
GGML_OP_UNARY,
GGML_UNARY_OP_SILU,
GGML_UNARY_OP_HARDSWISH,
GGML_UNARY_OP_HARDSIGMOID,
+ GGML_UNARY_OP_EXP,
GGML_UNARY_OP_COUNT,
};
struct ggml_context * ctx,
struct ggml_tensor * a);
+ GGML_API struct ggml_tensor * ggml_exp(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_exp_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
struct ggml_tensor * pw,
struct ggml_tensor * ph);
+ GGML_API struct ggml_tensor * ggml_rwkv_wkv(
+ struct ggml_context * ctx,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * r,
+ struct ggml_tensor * tf,
+ struct ggml_tensor * td,
+ struct ggml_tensor * state);
+
// custom operators
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
// TODO: optimize performance
inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
static const float GELU_COEF_A = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
"WIN_UNPART",
"GET_REL_POS",
"ADD_REL_POS",
+ "RWKV_WKV",
"UNARY",
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
"win_unpart(x)",
"get_rel_pos(x)",
"add_rel_pos(x)",
+ "rwkv_wkv(k, v, r, tf, td, s)",
"unary(x)",
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
"SILU",
"HARDSWISH",
"HARDSIGMOID",
+ "EXP",
};
-static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
+static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);
}
+// ggml exp
+struct ggml_tensor * ggml_exp(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_EXP);
+}
+
+struct ggml_tensor * ggml_exp_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
+}
+
// ggml_norm
static struct ggml_tensor * ggml_norm_impl(
return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
}
+// ggml_rwkv_wkv
+
+struct ggml_tensor * ggml_rwkv_wkv(
+ struct ggml_context * ctx,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * r,
+ struct ggml_tensor * tf,
+ struct ggml_tensor * td,
+ struct ggml_tensor * state) {
+ GGML_ASSERT(ggml_is_contiguous(k));
+ GGML_ASSERT(ggml_is_contiguous(v));
+ GGML_ASSERT(ggml_is_contiguous(r));
+ GGML_ASSERT(ggml_is_contiguous(tf));
+ GGML_ASSERT(ggml_is_contiguous(td));
+ GGML_ASSERT(ggml_is_contiguous(state));
+
+ const int64_t S = k->ne[0];
+ const int64_t H = k->ne[2];
+ const int64_t n_tokens = k->ne[3];
+ const int64_t n_seqs = state->ne[1];
+ {
+ GGML_ASSERT(k->ne[1] == 1);
+ GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
+ GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
+ // TODO: RWKV v4 and v5
+ GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+ GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+ }
+
+ bool is_node = false;
+
+ if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
+ GGML_ABORT("fatal error"); // TODO: implement backward
+ is_node = true;
+ }
+
+ // concat output and new_state
+ const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ result->op = GGML_OP_RWKV_WKV;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = k;
+ result->src[1] = v;
+ result->src[2] = r;
+ result->src[3] = tf;
+ result->src[4] = td;
+ result->src[5] = state;
+
+ return result;
+}
+
// ggml_unary
static struct ggml_tensor * ggml_unary_impl(
}
}
+static void ggml_compute_forward_exp_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_exp_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_exp(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_exp_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_norm
{
ggml_compute_forward_hardsigmoid(params, dst);
} break;
+ case GGML_UNARY_OP_EXP:
+ {
+ ggml_compute_forward_exp(params, dst);
+ } break;
default:
{
GGML_ABORT("fatal error");
}
}
+// ggml_compute_forward_rwkv_wkv
+
+static void ggml_compute_forward_rwkv_wkv_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ const size_t T = dst->src[1]->ne[3];
+ const size_t C = dst->ne[0];
+ const size_t H = dst->src[1]->ne[2];
+ const size_t n_seqs = dst->src[5]->ne[1];
+
+ float * dst_data = (float *) dst->data;
+ float * state = ((float *) dst->data) + C * T;
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ memset(dst_data, 0, T * C * sizeof(float));
+
+ float * k = (float *) dst->src[0]->data;
+ float * v = (float *) dst->src[1]->data;
+ float * r = (float *) dst->src[2]->data;
+ float * time_faaaa = (float *) dst->src[3]->data;
+ float * time_decay = (float *) dst->src[4]->data;
+
+ size_t t_stride = H * (C / H);
+
+ size_t h_stride = C / H;
+ size_t h_stride_2d = (C / H) * (C / H);
+
+ // basically fused operations:
+ // dst = r @ (time_faaaa * (k @ v) + state),
+ // state = time_decay * state + (k @ v),
+ // recursive through each token
+ for (size_t t = 0; t < T; t++) {
+ size_t t_offset = t * t_stride;
+ size_t state_offset = (C / H) * C * (t / (T / n_seqs));
+ float * state_cur = state + state_offset;
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+ for (size_t h = 0; h < H; h++) {
+ size_t h_offset = h * h_stride;
+ size_t t_h_offset = t_offset + h_offset;
+ size_t h_2d_offset = h * h_stride_2d;
+
+ for (size_t i = 0; i < C / H; i++) {
+ size_t t_h_i_offset = t_h_offset + i;
+ size_t h_i_offset = h_offset + i;
+ size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+ float k_val = k[t_h_i_offset];
+ float r_val = r[t_h_i_offset];
+ float time_faaaa_val = time_faaaa[h_i_offset];
+ // RWKV v6: different time_decay for each token.
+ float time_decay_val = time_decay[t_h_i_offset];
+
+ for (size_t j = 0; j < C / H; j ++) {
+ size_t t_h_j_offset = t_h_offset + j;
+ size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+ float v_val = v[t_h_j_offset];
+ float kv_val = v_val * k_val;
+ float prev_state_val = state_prev[h_2d_i_j_offset];
+ float temp_val = kv_val * time_faaaa_val + prev_state_val;
+ dst_data[t_h_j_offset] += temp_val * r_val;
+ state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_rwkv_wkv(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_rwkv_wkv_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_map_unary
static void ggml_compute_forward_map_unary_f32(
{
ggml_compute_forward_add_rel_pos(params, tensor);
} break;
+ case GGML_OP_RWKV_WKV:
+ {
+ ggml_compute_forward_rwkv_wkv(params, tensor);
+ } break;
case GGML_OP_MAP_UNARY:
{
ggml_unary_op_f32_t fun;
zero_table);
}
} break;
+ case GGML_UNARY_OP_EXP:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_mul(ctx, tensor, tensor->grad),
+ zero_table);
+ }
+ } break;
default:
GGML_ABORT("fatal error");
}
} break;
case GGML_OP_GET_REL_POS:
case GGML_OP_ADD_REL_POS:
+ case GGML_OP_RWKV_WKV:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1_F32:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_HARDSIGMOID:
+ case GGML_UNARY_OP_EXP:
{
n_tasks = 1;
} break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_GET_REL_POS:
+ case GGML_OP_RWKV_WKV:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1_F32:
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
+ RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers"
+ TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim"
+ TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim"
class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
+ class WKV:
+ HEAD_SIZE = "{arch}.wkv.head_size"
+
class Tokenizer:
MODEL = "tokenizer.ggml.model"
PRE = "tokenizer.ggml.pre"
GEMMA = auto()
GEMMA2 = auto()
STARCODER2 = auto()
+ RWKV6 = auto()
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
SSM_A = auto()
SSM_D = auto()
SSM_OUT = auto()
+ TIME_MIX_W1 = auto()
+ TIME_MIX_W2 = auto()
+ TIME_MIX_LERP_X = auto()
+ TIME_MIX_LERP_K = auto()
+ TIME_MIX_LERP_V = auto()
+ TIME_MIX_LERP_R = auto()
+ TIME_MIX_LERP_G = auto()
+ TIME_MIX_LERP_W = auto()
+ TIME_MIX_FIRST = auto()
+ TIME_MIX_DECAY = auto()
+ TIME_MIX_DECAY_W1 = auto()
+ TIME_MIX_DECAY_W2 = auto()
+ TIME_MIX_KEY = auto()
+ TIME_MIX_VALUE = auto()
+ TIME_MIX_RECEPTANCE = auto()
+ TIME_MIX_GATE = auto()
+ TIME_MIX_LN = auto()
+ TIME_MIX_OUTPUT = auto()
+ CHANNEL_MIX_LERP_K = auto()
+ CHANNEL_MIX_LERP_R = auto()
+ CHANNEL_MIX_KEY = auto()
+ CHANNEL_MIX_RECEPTANCE = auto()
+ CHANNEL_MIX_VALUE = auto()
ATTN_Q_A = auto()
ATTN_Q_B = auto()
ATTN_KV_A_MQA = auto()
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.STARCODER2: "starcoder2",
+ MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
- MODEL_TENSOR.TOKEN_EMBD: "token_embd",
- MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
- MODEL_TENSOR.TOKEN_TYPES: "token_types",
- MODEL_TENSOR.POS_EMBD: "position_embd",
- MODEL_TENSOR.OUTPUT_NORM: "output_norm",
- MODEL_TENSOR.OUTPUT: "output",
- MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
- MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
- MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
- MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
- MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
- MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
- MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
- MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
- MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
- MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
- MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
- MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
- MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
- MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
- MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm",
- MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
- MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
- MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
- MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm",
- MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm",
- MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
- MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
- MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
- MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp",
- MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp",
- MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp",
- MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn",
- MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps",
- MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
- MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
- MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
- MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
- MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
- MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
- MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
- MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
- MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
- MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
- MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
- MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a",
- MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
- MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
- MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
- MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
- MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
- MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
- MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm",
- MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm",
- MODEL_TENSOR.DEC_ATTN_Q: "dec.blk.{bid}.attn_q",
- MODEL_TENSOR.DEC_ATTN_K: "dec.blk.{bid}.attn_k",
- MODEL_TENSOR.DEC_ATTN_V: "dec.blk.{bid}.attn_v",
- MODEL_TENSOR.DEC_ATTN_OUT: "dec.blk.{bid}.attn_o",
- MODEL_TENSOR.DEC_ATTN_REL_B: "dec.blk.{bid}.attn_rel_b",
- MODEL_TENSOR.DEC_CROSS_ATTN_NORM: "dec.blk.{bid}.cross_attn_norm",
- MODEL_TENSOR.DEC_CROSS_ATTN_Q: "dec.blk.{bid}.cross_attn_q",
- MODEL_TENSOR.DEC_CROSS_ATTN_K: "dec.blk.{bid}.cross_attn_k",
- MODEL_TENSOR.DEC_CROSS_ATTN_V: "dec.blk.{bid}.cross_attn_v",
- MODEL_TENSOR.DEC_CROSS_ATTN_OUT: "dec.blk.{bid}.cross_attn_o",
- MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: "dec.blk.{bid}.cross_attn_rel_b",
- MODEL_TENSOR.DEC_FFN_NORM: "dec.blk.{bid}.ffn_norm",
- MODEL_TENSOR.DEC_FFN_GATE: "dec.blk.{bid}.ffn_gate",
- MODEL_TENSOR.DEC_FFN_DOWN: "dec.blk.{bid}.ffn_down",
- MODEL_TENSOR.DEC_FFN_UP: "dec.blk.{bid}.ffn_up",
- MODEL_TENSOR.DEC_OUTPUT_NORM: "dec.output_norm",
- MODEL_TENSOR.ENC_ATTN_NORM: "enc.blk.{bid}.attn_norm",
- MODEL_TENSOR.ENC_ATTN_Q: "enc.blk.{bid}.attn_q",
- MODEL_TENSOR.ENC_ATTN_K: "enc.blk.{bid}.attn_k",
- MODEL_TENSOR.ENC_ATTN_V: "enc.blk.{bid}.attn_v",
- MODEL_TENSOR.ENC_ATTN_OUT: "enc.blk.{bid}.attn_o",
- MODEL_TENSOR.ENC_ATTN_REL_B: "enc.blk.{bid}.attn_rel_b",
- MODEL_TENSOR.ENC_FFN_NORM: "enc.blk.{bid}.ffn_norm",
- MODEL_TENSOR.ENC_FFN_GATE: "enc.blk.{bid}.ffn_gate",
- MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down",
- MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
- MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
+ MODEL_TENSOR.TOKEN_EMBD: "token_embd",
+ MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
+ MODEL_TENSOR.TOKEN_TYPES: "token_types",
+ MODEL_TENSOR.POS_EMBD: "position_embd",
+ MODEL_TENSOR.OUTPUT_NORM: "output_norm",
+ MODEL_TENSOR.OUTPUT: "output",
+ MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
+ MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
+ MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
+ MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
+ MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
+ MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
+ MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
+ MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
+ MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
+ MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
+ MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
+ MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
+ MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
+ MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
+ MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm",
+ MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
+ MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
+ MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
+ MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm",
+ MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm",
+ MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
+ MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
+ MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
+ MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp",
+ MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp",
+ MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp",
+ MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn",
+ MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps",
+ MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
+ MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
+ MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
+ MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
+ MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
+ MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
+ MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
+ MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
+ MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
+ MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
+ MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
+ MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
+ MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
+ MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x",
+ MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k",
+ MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
+ MODEL_TENSOR.TIME_MIX_LERP_R: "blk.{bid}.time_mix_lerp_r",
+ MODEL_TENSOR.TIME_MIX_LERP_G: "blk.{bid}.time_mix_lerp_g",
+ MODEL_TENSOR.TIME_MIX_LERP_W: "blk.{bid}.time_mix_lerp_w",
+ MODEL_TENSOR.TIME_MIX_FIRST: "blk.{bid}.time_mix_first",
+ MODEL_TENSOR.TIME_MIX_DECAY: "blk.{bid}.time_mix_decay",
+ MODEL_TENSOR.TIME_MIX_DECAY_W1: "blk.{bid}.time_mix_decay_w1",
+ MODEL_TENSOR.TIME_MIX_DECAY_W2: "blk.{bid}.time_mix_decay_w2",
+ MODEL_TENSOR.TIME_MIX_KEY: "blk.{bid}.time_mix_key",
+ MODEL_TENSOR.TIME_MIX_VALUE: "blk.{bid}.time_mix_value",
+ MODEL_TENSOR.TIME_MIX_RECEPTANCE: "blk.{bid}.time_mix_receptance",
+ MODEL_TENSOR.TIME_MIX_GATE: "blk.{bid}.time_mix_gate",
+ MODEL_TENSOR.TIME_MIX_LN: "blk.{bid}.time_mix_ln",
+ MODEL_TENSOR.TIME_MIX_OUTPUT: "blk.{bid}.time_mix_output",
+ MODEL_TENSOR.CHANNEL_MIX_LERP_K: "blk.{bid}.channel_mix_lerp_k",
+ MODEL_TENSOR.CHANNEL_MIX_LERP_R: "blk.{bid}.channel_mix_lerp_r",
+ MODEL_TENSOR.CHANNEL_MIX_KEY: "blk.{bid}.channel_mix_key",
+ MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: "blk.{bid}.channel_mix_receptance",
+ MODEL_TENSOR.CHANNEL_MIX_VALUE: "blk.{bid}.channel_mix_value",
+ MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a",
+ MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
+ MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
+ MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
+ MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
+ MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
+ MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
+ MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm",
+ MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm",
+ MODEL_TENSOR.DEC_ATTN_Q: "dec.blk.{bid}.attn_q",
+ MODEL_TENSOR.DEC_ATTN_K: "dec.blk.{bid}.attn_k",
+ MODEL_TENSOR.DEC_ATTN_V: "dec.blk.{bid}.attn_v",
+ MODEL_TENSOR.DEC_ATTN_OUT: "dec.blk.{bid}.attn_o",
+ MODEL_TENSOR.DEC_ATTN_REL_B: "dec.blk.{bid}.attn_rel_b",
+ MODEL_TENSOR.DEC_CROSS_ATTN_NORM: "dec.blk.{bid}.cross_attn_norm",
+ MODEL_TENSOR.DEC_CROSS_ATTN_Q: "dec.blk.{bid}.cross_attn_q",
+ MODEL_TENSOR.DEC_CROSS_ATTN_K: "dec.blk.{bid}.cross_attn_k",
+ MODEL_TENSOR.DEC_CROSS_ATTN_V: "dec.blk.{bid}.cross_attn_v",
+ MODEL_TENSOR.DEC_CROSS_ATTN_OUT: "dec.blk.{bid}.cross_attn_o",
+ MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: "dec.blk.{bid}.cross_attn_rel_b",
+ MODEL_TENSOR.DEC_FFN_NORM: "dec.blk.{bid}.ffn_norm",
+ MODEL_TENSOR.DEC_FFN_GATE: "dec.blk.{bid}.ffn_gate",
+ MODEL_TENSOR.DEC_FFN_DOWN: "dec.blk.{bid}.ffn_down",
+ MODEL_TENSOR.DEC_FFN_UP: "dec.blk.{bid}.ffn_up",
+ MODEL_TENSOR.DEC_OUTPUT_NORM: "dec.output_norm",
+ MODEL_TENSOR.ENC_ATTN_NORM: "enc.blk.{bid}.attn_norm",
+ MODEL_TENSOR.ENC_ATTN_Q: "enc.blk.{bid}.attn_q",
+ MODEL_TENSOR.ENC_ATTN_K: "enc.blk.{bid}.attn_k",
+ MODEL_TENSOR.ENC_ATTN_V: "enc.blk.{bid}.attn_v",
+ MODEL_TENSOR.ENC_ATTN_OUT: "enc.blk.{bid}.attn_o",
+ MODEL_TENSOR.ENC_ATTN_REL_B: "enc.blk.{bid}.attn_rel_b",
+ MODEL_TENSOR.ENC_FFN_NORM: "enc.blk.{bid}.ffn_norm",
+ MODEL_TENSOR.ENC_FFN_GATE: "enc.blk.{bid}.ffn_gate",
+ MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down",
+ MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
+ MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
+ MODEL_ARCH.RWKV6: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_NORM_2,
+ MODEL_TENSOR.TIME_MIX_W1,
+ MODEL_TENSOR.TIME_MIX_W2,
+ MODEL_TENSOR.TIME_MIX_LERP_X,
+ MODEL_TENSOR.TIME_MIX_LERP_K,
+ MODEL_TENSOR.TIME_MIX_LERP_V,
+ MODEL_TENSOR.TIME_MIX_LERP_R,
+ MODEL_TENSOR.TIME_MIX_LERP_G,
+ MODEL_TENSOR.TIME_MIX_LERP_W,
+ MODEL_TENSOR.TIME_MIX_FIRST,
+ MODEL_TENSOR.TIME_MIX_DECAY,
+ MODEL_TENSOR.TIME_MIX_DECAY_W1,
+ MODEL_TENSOR.TIME_MIX_DECAY_W2,
+ MODEL_TENSOR.TIME_MIX_KEY,
+ MODEL_TENSOR.TIME_MIX_VALUE,
+ MODEL_TENSOR.TIME_MIX_RECEPTANCE,
+ MODEL_TENSOR.TIME_MIX_GATE,
+ MODEL_TENSOR.TIME_MIX_LN,
+ MODEL_TENSOR.TIME_MIX_OUTPUT,
+ MODEL_TENSOR.CHANNEL_MIX_LERP_K,
+ MODEL_TENSOR.CHANNEL_MIX_LERP_R,
+ MODEL_TENSOR.CHANNEL_MIX_KEY,
+ MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE,
+ MODEL_TENSOR.CHANNEL_MIX_VALUE,
+ ],
MODEL_ARCH.MAMBA: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
def add_expert_weights_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
+ def add_rescale_every_n_layers(self, count: int) -> None:
+ self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)
+
+ def add_time_mix_extra_dim(self, dim: int) -> None:
+ self.add_uint32(Keys.LLM.TIME_MIX_EXTRA_DIM.format(arch=self.arch), dim)
+
+ def add_time_decay_extra_dim(self, dim: int) -> None:
+ self.add_uint32(Keys.LLM.TIME_DECAY_EXTRA_DIM.format(arch=self.arch), dim)
+
+ def add_wkv_head_size(self, size: int) -> None:
+ self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
+
def add_layer_norm_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
"embedding.word_embeddings", # chatglm
"transformer.token_embeddings", # openelm
"shared", # t5
+ "rwkv.embeddings", # rwkv
),
# Token type embeddings
"embeddings.LayerNorm", # bert
"emb_ln", # nomic-bert
"transformer.norm", # openelm
+ "rwkv.blocks.0.pre_ln", # rwkv
),
# Position embeddings
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
"output_layer", # chatglm
+ "head", # rwkv
),
# Output norm
"encoder.final_layernorm", # chatglm
"transformer.norm", # openelm
"model.norm", # nemotron
+ "rwkv.ln_out", # rwkv
),
# Rope frequencies
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
"encoder.layers.{bid}.input_layernorm", # chatglm
"transformer.layers.{bid}.attn_norm", # openelm
+ "rwkv.blocks.{bid}.ln1", # rwkv
),
# Attention norm 2
MODEL_TENSOR.ATTN_NORM_2: (
- "transformer.h.{bid}.ln_attn", # falcon40b
+ "transformer.h.{bid}.ln_attn", # falcon40b
"encoder.layer.{bid}.layer_norm_1", # jina-v2-code
+ "rwkv.blocks.{bid}.ln2", # rwkv
),
# Attention query-key-value
"backbone.layers.{bid}.mixer.out_proj",
),
+ MODEL_TENSOR.TIME_MIX_W1: (
+ "rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6
+ ),
+
+ MODEL_TENSOR.TIME_MIX_W2: (
+ "rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6
+ ),
+
+ MODEL_TENSOR.TIME_MIX_LERP_X: (
+ "rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6
+ ),
+
+ MODEL_TENSOR.TIME_MIX_LERP_K: (
+ "rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6
+ ),
+
+ MODEL_TENSOR.TIME_MIX_LERP_V: (
+ "rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6
+ ),
+
+ MODEL_TENSOR.TIME_MIX_LERP_R: (
+ "rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6
+ ),
+
+ MODEL_TENSOR.TIME_MIX_LERP_G: (
+ "rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6
+ ),
+
+ MODEL_TENSOR.TIME_MIX_LERP_W: (
+ "rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6
+ ),
+
+ MODEL_TENSOR.TIME_MIX_FIRST: (
+ "rwkv.blocks.{bid}.attention.time_faaaa", # rwkv v6
+ ),
+
+ MODEL_TENSOR.TIME_MIX_DECAY: (
+ "rwkv.blocks.{bid}.attention.time_decay", # rwkv v6
+ ),
+
+ MODEL_TENSOR.TIME_MIX_DECAY_W1: (
+ "rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6
+ ),
+
+ MODEL_TENSOR.TIME_MIX_DECAY_W2: (
+ "rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6
+ ),
+
+ MODEL_TENSOR.TIME_MIX_KEY: (
+ "rwkv.blocks.{bid}.attention.key", # rwkv
+ ),
+
+ MODEL_TENSOR.TIME_MIX_VALUE: (
+ "rwkv.blocks.{bid}.attention.value", # rwkv
+ ),
+
+ MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
+ "rwkv.blocks.{bid}.attention.receptance", # rwkv
+ ),
+
+ MODEL_TENSOR.TIME_MIX_GATE: (
+ "rwkv.blocks.{bid}.attention.gate", # rwkv
+ ),
+
+ MODEL_TENSOR.TIME_MIX_LN: (
+ "rwkv.blocks.{bid}.attention.ln_x", # rwkv
+ ),
+
+ MODEL_TENSOR.TIME_MIX_OUTPUT: (
+ "rwkv.blocks.{bid}.attention.output", # rwkv
+ ),
+
+ MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
+ "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv v6
+ ),
+
+ MODEL_TENSOR.CHANNEL_MIX_LERP_R: (
+ "rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv v6
+ ),
+
+ MODEL_TENSOR.CHANNEL_MIX_KEY: (
+ "rwkv.blocks.{bid}.feed_forward.key", # rwkv
+ ),
+
+ MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: (
+ "rwkv.blocks.{bid}.feed_forward.receptance", # rwkv
+ ),
+
+ MODEL_TENSOR.CHANNEL_MIX_VALUE: (
+ "rwkv.blocks.{bid}.feed_forward.value", # rwkv
+ ),
+
MODEL_TENSOR.ATTN_Q_A: (
"model.layers.{bid}.self_attn.q_a_proj", # deepseek2
),
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
+ LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
};
// pre-tokenization types
auto res = children.find(c);
if (res != children.end()) {
return res->second.get_longest_prefix(key, len, offset + 1);
- } else {
- return std::make_pair(key, offset);
}
+
+ return std::make_pair(key, offset);
}
- struct naive_trie * traverse(const char c) {
+ const struct naive_trie * traverse(const char c) const {
auto res = children.find(c);
if (res != children.end()) {
return &res->second;
- } else {
- return NULL;
}
+
+ return NULL;
}
std::map<char, struct naive_trie> children;
bool has_value;
// traverse the token matcher trie to find a matching token
bool single_codepoint_token_found = false;
const struct best_tokenization & current_best = tokenization_results[input_offset];
- struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
+ const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
while (prefix_offset <= input_len && node != NULL) {
// check if we found valid token in prefix
struct naive_trie token_matcher;
};
+//
+// RWKV tokenizer
+//
+
+static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
+ std::vector<uint8_t> output;
+ output.reserve(escaped.size());
+
+ // Parser state
+ bool escaping = false;
+ uint8_t hex_remaining = 0;
+ uint8_t hex_acc = 0;
+
+ // Step through characters, performing parsing
+ for (const char & c : escaped) {
+ // If we're parsing a hex code, interpret the next character
+ if (hex_remaining != 0) {
+ uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0');
+ hex_acc = (hex_acc << 4) + value;
+
+ hex_remaining -= 1;
+ if (hex_remaining == 0) {
+ output.push_back(hex_acc);
+ hex_acc = 0;
+ }
+
+ continue;
+ }
+
+ // If we got an escape character, interpret it
+ if (escaping) {
+ if (c == 't') {
+ output.push_back('\t');
+ } else if (c == 'n') {
+ output.push_back('\n');
+ } else if (c == 'r') {
+ output.push_back('\r');
+ } else if (c == 'x') {
+ hex_remaining = 2;
+ } else {
+ output.push_back(c);
+ }
+
+ escaping = false;
+ continue;
+ }
+
+ if (c == '\\') {
+ escaping = true;
+ continue;
+ }
+
+ output.push_back(c);
+ }
+
+ return output;
+}
+
+struct llm_tokenizer_rwkv {
+ llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
+ // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
+ // For now, we decode the vocab here into the lookup we'll use for tokenization.
+
+ // build trie
+ for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+ const auto & token = vocab.id_to_token[id];
+ const auto data = llama_unescape_rwkv_token(token.text);
+ token_matcher.insert((const char *) data.data(), data.size(), id);
+ }
+ }
+
+ void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+ uint32_t position = 0;
+
+ while (position < text.size()) {
+ const struct naive_trie * node = token_matcher.traverse(text[position]);
+ if (node == NULL) {
+ // no matching token found, add unknown token
+ output.push_back(vocab.special_unk_id);
+ position += 1;
+ continue;
+ }
+
+ // traverse the trie to find the longest matching token
+ uint32_t token_id = 0;
+ uint32_t token_length = 0;
+ while (node != NULL) {
+ if (node->has_value) {
+ token_id = node->value;
+ token_length = position + 1;
+ }
+ node = node->traverse(text[++position]);
+ }
+
+ // add the longest matching token
+ output.push_back(token_id);
+ position = token_length;
+ }
+ }
+
+ const llama_vocab & vocab;
+
+ struct naive_trie token_matcher;
+};
+
//
// (de-) tokenize
//
output.push_back(vocab.special_eos_id);
}
} break;
+ case LLAMA_VOCAB_TYPE_RWKV:
+ {
+ for (const auto & fragment : fragment_buffer) {
+ if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+ auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+ LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+
+ llm_tokenizer_rwkv tokenizer(vocab);
+ tokenizer.tokenize(raw_text, output);
+ } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+ output.push_back(fragment.token);
+ }
+ }
+ } break;
case LLAMA_VOCAB_TYPE_NONE:
GGML_ABORT("fatal error");
}
}
break;
}
+ case LLAMA_VOCAB_TYPE_RWKV: {
+ std::vector<uint8_t> result = llama_unescape_rwkv_token(token_text);
+
+ // If we don't have enough space, return an error
+ if (result.size() > (size_t)length) {
+ return -(int)result.size();
+ }
+
+ memcpy(buf, result.data(), result.size());
+ return (int)result.size();
+ }
default:
GGML_ABORT("fatal error");
}
LLM_ARCH_JAIS,
LLM_ARCH_NEMOTRON,
LLM_ARCH_EXAONE,
+ LLM_ARCH_RWKV6,
LLM_ARCH_UNKNOWN,
};
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_EXAONE, "exaone" },
+ { LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
LLM_KV_DECODER_START_TOKEN_ID,
LLM_KV_ATTN_LOGIT_SOFTCAPPING,
LLM_KV_FINAL_LOGIT_SOFTCAPPING,
+ LLM_KV_RESCALE_EVERY_N_LAYERS,
+ LLM_KV_TIME_MIX_EXTRA_DIM,
+ LLM_KV_TIME_DECAY_EXTRA_DIM,
LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
LLM_KV_SSM_TIME_STEP_RANK,
LLM_KV_SSM_DT_B_C_RMS,
+ LLM_KV_WKV_HEAD_SIZE,
+
LLM_KV_TOKENIZER_MODEL,
LLM_KV_TOKENIZER_PRE,
LLM_KV_TOKENIZER_LIST,
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
{ LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
- { LLM_KV_POOLING_TYPE , "%s.pooling_type" },
+ { LLM_KV_POOLING_TYPE, "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
{ LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
{ LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
+ { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" },
+ { LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" },
+ { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
{ LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
+ { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
+
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_OUT,
+ LLM_TENSOR_TIME_MIX_W1,
+ LLM_TENSOR_TIME_MIX_W2,
+ LLM_TENSOR_TIME_MIX_LERP_X,
+ LLM_TENSOR_TIME_MIX_LERP_W,
+ LLM_TENSOR_TIME_MIX_LERP_K,
+ LLM_TENSOR_TIME_MIX_LERP_V,
+ LLM_TENSOR_TIME_MIX_LERP_R,
+ LLM_TENSOR_TIME_MIX_LERP_G,
+ LLM_TENSOR_TIME_MIX_FIRST,
+ LLM_TENSOR_TIME_MIX_DECAY,
+ LLM_TENSOR_TIME_MIX_DECAY_W1,
+ LLM_TENSOR_TIME_MIX_DECAY_W2,
+ LLM_TENSOR_TIME_MIX_KEY,
+ LLM_TENSOR_TIME_MIX_VALUE,
+ LLM_TENSOR_TIME_MIX_RECEPTANCE,
+ LLM_TENSOR_TIME_MIX_GATE,
+ LLM_TENSOR_TIME_MIX_LN,
+ LLM_TENSOR_TIME_MIX_OUTPUT,
+ LLM_TENSOR_CHANNEL_MIX_LERP_K,
+ LLM_TENSOR_CHANNEL_MIX_LERP_R,
+ LLM_TENSOR_CHANNEL_MIX_KEY,
+ LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,
+ LLM_TENSOR_CHANNEL_MIX_VALUE,
LLM_TENSOR_ATTN_Q_A,
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_RWKV6,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
+ { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
+ { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
+ { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" },
+ { LLM_TENSOR_TIME_MIX_LERP_W, "blk.%d.time_mix_lerp_w" },
+ { LLM_TENSOR_TIME_MIX_LERP_K, "blk.%d.time_mix_lerp_k" },
+ { LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" },
+ { LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" },
+ { LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" },
+ { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" },
+ { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" },
+ { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" },
+ { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" },
+ { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
+ { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
+ { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
+ { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" },
+ { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
+ { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
+ { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" },
+ { LLM_TENSOR_CHANNEL_MIX_LERP_R, "blk.%d.channel_mix_lerp_r" },
+ { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" },
+ { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
+ { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" },
+ },
+ },
{
LLM_ARCH_UNKNOWN,
{
MODEL_1B,
MODEL_1_3B,
MODEL_1_4B,
+ MODEL_1_6B,
MODEL_2B,
MODEL_2_8B,
MODEL_3B,
float f_attn_logit_softcapping = 50.0f;
float f_final_logit_softcapping = 30.0f;
+ // for RWKV
+ uint32_t rescale_every_n_layers = 0;
+ uint32_t time_mix_extra_dim = 0;
+ uint32_t time_decay_extra_dim = 0;
+ uint32_t wkv_head_size = 0;
+
float rope_attn_factor = 1.0f;
float rope_freq_base_train;
float rope_freq_scale_train;
if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
+ if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true;
+ if (this->time_mix_extra_dim != other.time_mix_extra_dim) return true;
+ if (this->time_decay_extra_dim != other.time_decay_extra_dim) return true;
+ if (this->wkv_head_size != other.wkv_head_size) return true;
+
if (this->dec_start_token_id != other.dec_start_token_id) return true;
const float EPSILON = 1e-9f;
}
uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
- // corresponds to Mamba's conv_states size
- // TODO: maybe support other convolution strides than 1
- // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
- return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+ // corresponds to Mamba's conv_states size or RWKV's token_shift states size
+ if (wkv_head_size != 0) {
+ // for RWKV models
+ return 2 * n_embd;
+ } else {
+ // TODO: maybe support other convolution strides than 1
+ // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+ return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+ }
}
uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
- // corresponds to Mamba's ssm_states size
- return ssm_d_state * ssm_d_inner;
+ if (wkv_head_size != 0) {
+ // corresponds to RWKV's wkv_states size
+ return n_embd * wkv_head_size;
+ } else {
+ // corresponds to Mamba's ssm_states size
+ return ssm_d_state * ssm_d_inner;
+ }
}
};
struct ggml_tensor * ssm_conv1d_b;
struct ggml_tensor * ssm_dt_b;
+ // rwkv
+ struct ggml_tensor * time_mix_w1;
+ struct ggml_tensor * time_mix_w2;
+ struct ggml_tensor * time_mix_lerp_x;
+ struct ggml_tensor * time_mix_lerp_w;
+ struct ggml_tensor * time_mix_lerp_k;
+ struct ggml_tensor * time_mix_lerp_v;
+ struct ggml_tensor * time_mix_lerp_r;
+ struct ggml_tensor * time_mix_lerp_g;
+
+ struct ggml_tensor * time_mix_first;
+ struct ggml_tensor * time_mix_decay;
+ struct ggml_tensor * time_mix_decay_w1;
+ struct ggml_tensor * time_mix_decay_w2;
+ struct ggml_tensor * time_mix_key;
+ struct ggml_tensor * time_mix_value;
+ struct ggml_tensor * time_mix_receptance;
+ struct ggml_tensor * time_mix_gate;
+
+ struct ggml_tensor * time_mix_ln;
+ struct ggml_tensor * time_mix_ln_b;
+ struct ggml_tensor * time_mix_output;
+
+ struct ggml_tensor * channel_mix_lerp_k;
+ struct ggml_tensor * channel_mix_lerp_r;
+
+ struct ggml_tensor * channel_mix_key;
+ struct ggml_tensor * channel_mix_receptance;
+ struct ggml_tensor * channel_mix_value;
+
// long rope factors
struct ggml_tensor * rope_long = nullptr;
struct ggml_tensor * rope_short = nullptr;
const uint32_t n_seq_tokens = batch.n_seq_tokens;
if (cache.recurrent) {
- // For recurrent state architectures (like Mamba),
+ // For recurrent state architectures (like Mamba or RWKV),
// each cache cell can store the state for a whole sequence.
// A slot should be always be contiguous.
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
- // models like Mamba can't have a state partially erased
+ // models like Mamba or RWKV can't have a state partially erased
if (cache.recurrent) {
if (seq_id >= (int64_t) cache.size) {
// could be fatal
if (p0 == p1) return;
if (cache.recurrent) {
- // for Mamba-like models, only the pos needs to be shifted
+ // for Mamba-like or RWKV models, only the pos needs to be shifted
if (0 <= seq_id && seq_id < (int64_t) cache.size) {
const int32_t tail_id = cache.cells[seq_id].tail;
if (tail_id >= 0) {
if (p0 == p1) return;
if (cache.recurrent) {
- // for Mamba-like models, only the pos needs to be changed
+ // for Mamba-like or RWKV models, only the pos needs to be changed
if (0 <= seq_id && seq_id < (int64_t) cache.size) {
const int32_t tail_id = cache.cells[seq_id].tail;
if (tail_id >= 0) {
case MODEL_1B: return "1B";
case MODEL_1_3B: return "1.3B";
case MODEL_1_4B: return "1.4B";
+ case MODEL_1_6B: return "1.6B";
case MODEL_2B: return "2B";
case MODEL_2_8B: return "2.8B";
case MODEL_3B: return "3B";
case LLAMA_VOCAB_TYPE_BPE: return "BPE";
case LLAMA_VOCAB_TYPE_WPM: return "WPM";
case LLAMA_VOCAB_TYPE_UGM: return "UGM";
+ case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
default: return "unknown";
}
}
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_RWKV6:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+ ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
+ ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
+ ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
+ ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false);
+
+ switch (hparams.n_layer) {
+ case 24: model.type = e_model::MODEL_1_6B; break;
+ case 32:
+ switch (hparams.n_embd) {
+ case 2560: model.type = e_model::MODEL_3B; break;
+ case 4096: model.type = e_model::MODEL_7B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ } break;
+ case 61: model.type = e_model::MODEL_14B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
default: (void)0;
}
}
#endif
}
+ } else if (tokenizer_model == "rwkv") {
+ vocab.type = LLAMA_VOCAB_TYPE_RWKV;
+
+ // default special tokens
+ vocab.special_bos_id = -1;
+ vocab.special_eos_id = -1;
+ vocab.special_unk_id = -1;
+ vocab.special_sep_id = -1;
+ vocab.special_pad_id = -1;
} else {
throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
}
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
vocab.tokenizer_add_bos = false;
vocab.tokenizer_add_eos = true;
+ } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) {
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+ vocab.tokenizer_add_space_prefix = false;
+ vocab.tokenizer_clean_spaces = false;
+ vocab.tokenizer_add_bos = false;
+ vocab.tokenizer_add_eos = false;
} else {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
}
}
} else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
vocab.linefeed_id = vocab.special_pad_id;
+ } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) {
+ const std::vector<int> ids = llama_tokenize_internal(vocab, "\n", false);
+ GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
+ vocab.linefeed_id = ids[0];
} else {
const std::vector<int> ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A
GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
}
} break;
+ case LLM_ARCH_RWKV6:
+ {
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+ // Block 0, LN0
+ model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
+ model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd});
+
+ // output
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+ model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
+ model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+
+ const int time_mix_extra_dim = hparams.time_mix_extra_dim;
+ const int time_decay_extra_dim = hparams.time_decay_extra_dim;
+ const int head_size = hparams.wkv_head_size;
+ const int attn_hidden_size = n_embd;
+ const int ffn_size = hparams.n_ff_arr[0];
+
+ for (int i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_layer = ctx_for_layer(i);
+
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+ layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
+
+ layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd});
+ layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd});
+
+ layer.time_mix_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5});
+ layer.time_mix_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5});
+
+ layer.time_mix_lerp_x = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1});
+ layer.time_mix_lerp_w = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1});
+ layer.time_mix_lerp_k = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1});
+ layer.time_mix_lerp_v = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1});
+ layer.time_mix_lerp_r = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1});
+ layer.time_mix_lerp_g = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1});
+
+ layer.time_mix_first = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size});
+ layer.time_mix_decay = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd});
+ layer.time_mix_decay_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim});
+ layer.time_mix_decay_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size});
+ layer.time_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd});
+ layer.time_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd});
+ layer.time_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd});
+ layer.time_mix_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd});
+
+ layer.time_mix_ln = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd});
+ layer.time_mix_ln_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd});
+ layer.time_mix_output = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size});
+
+ layer.channel_mix_lerp_k = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1});
+ layer.channel_mix_lerp_r = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1});
+
+ layer.channel_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size});
+ layer.channel_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd});
+ layer.channel_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd});
+ }
+
+ } break;
default:
throw std::runtime_error("unknown architecture");
}
return cur;
}
+static struct ggml_tensor * llm_build_rwkv6_time_mix(
+ struct llama_context & lctx,
+ struct ggml_context * ctx,
+ const struct llama_layer * layer,
+ struct ggml_tensor * cur,
+ struct ggml_tensor * x_prev,
+ struct ggml_tensor ** wkv_state) {
+ size_t n_embed = cur->ne[0];
+ size_t n_seq_tokens = cur->ne[1];
+ size_t n_seqs = cur->ne[2];
+
+ size_t head_size = layer->time_mix_first->ne[0];
+ size_t head_count = layer->time_mix_first->ne[1];
+
+ size_t n_tokens = n_seqs * n_seq_tokens;
+
+ struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
+
+ sx = ggml_reshape_2d(ctx, sx, n_embed, n_tokens);
+ cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
+
+ struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);
+
+ xxx = ggml_reshape_4d(
+ ctx,
+ ggml_tanh(
+ ctx,
+ ggml_mul_mat(ctx, layer->time_mix_w1, xxx)
+ ),
+ layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
+ );
+
+ xxx = ggml_cont(ctx, ggml_permute(ctx, xxx, 0, 1, 3, 2));
+
+ xxx = ggml_mul_mat(
+ ctx,
+ ggml_reshape_4d(
+ ctx,
+ layer->time_mix_w2,
+ layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
+ ),
+ xxx
+ );
+
+ struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], 0);
+ struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * sizeof(float));
+ struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 2 * sizeof(float));
+ struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 3 * sizeof(float));
+ struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 4 * sizeof(float));
+
+ struct ggml_tensor * xw = ggml_add(
+ ctx,
+ ggml_mul(
+ ctx,
+ ggml_add(ctx, mw, layer->time_mix_lerp_w),
+ sx
+ ),
+ cur
+ );
+
+ struct ggml_tensor * xk = ggml_add(
+ ctx,
+ ggml_mul(
+ ctx,
+ ggml_add(ctx, mk, layer->time_mix_lerp_k),
+ sx
+ ),
+ cur
+ );
+
+ struct ggml_tensor * xv = ggml_add(
+ ctx,
+ ggml_mul(
+ ctx,
+ ggml_add(ctx, mv, layer->time_mix_lerp_v),
+ sx
+ ),
+ cur
+ );
+
+ struct ggml_tensor * xr = ggml_add(
+ ctx,
+ ggml_mul(
+ ctx,
+ ggml_add(ctx, mr, layer->time_mix_lerp_r),
+ sx
+ ),
+ cur
+ );
+
+ struct ggml_tensor * xg = ggml_add(
+ ctx,
+ ggml_mul(
+ ctx,
+ ggml_add(ctx, mg, layer->time_mix_lerp_g),
+ sx
+ ),
+ cur
+ );
+
+ struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens);
+ struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens);
+ struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens);
+ struct ggml_tensor * g = ggml_silu(
+ ctx,
+ llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
+ );
+
+ struct ggml_tensor * w = ggml_mul_mat(
+ ctx,
+ layer->time_mix_decay_w2,
+ ggml_tanh(
+ ctx,
+ ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw)
+ )
+ );
+
+ w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed));
+ w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
+ w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
+
+ k = ggml_transpose(ctx, k);
+ v = ggml_transpose(ctx, v);
+ r = ggml_transpose(ctx, r);
+
+ struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+ cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
+ *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_seqs, n_embed * n_tokens * sizeof(float));
+
+ // group norm with head_count groups
+ cur = ggml_reshape_3d(ctx, cur, n_embed / head_count, head_count, n_tokens);
+ cur = ggml_norm(ctx, cur, 64e-5f);
+
+ // Convert back to regular vectors.
+ cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
+ cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+
+ cur = ggml_mul(ctx, cur, g);
+ cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
+
+ return ggml_reshape_3d(ctx, cur, n_embed, n_seq_tokens, n_seqs);
+}
+
+static struct ggml_tensor * llm_build_rwkv6_channel_mix(
+ struct llama_context & lctx,
+ struct ggml_context * ctx,
+ const struct llama_layer * layer,
+ struct ggml_tensor * cur,
+ struct ggml_tensor * x_prev) {
+ struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
+ struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur);
+ struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur);
+
+ struct ggml_tensor * r = ggml_sigmoid(ctx, llm_build_lora_mm(lctx, ctx, layer->channel_mix_receptance, xr));
+ struct ggml_tensor * k = ggml_sqr(
+ ctx,
+ ggml_relu(
+ ctx,
+ llm_build_lora_mm(lctx, ctx, layer->channel_mix_key, xk)
+ )
+ );
+
+ return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
+}
+
struct llm_build_context {
const llama_model & model;
llama_context & lctx;
return gf;
}
+
+ ggml_cgraph * build_rwkv6() {
+ ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+ // Token shift state dimensions should be 2 * n_emb
+ GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
+
+ const int64_t n_seqs = batch.n_seqs;
+ const int64_t n_seq_tokens = batch.n_seq_tokens;
+ const int64_t n_tokens = batch.n_tokens;
+ GGML_ASSERT(n_seqs != 0);
+ GGML_ASSERT(batch.equal_seqs);
+ GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+ struct ggml_tensor * state_copy = build_inp_s_copy();
+ struct ggml_tensor * state_mask = build_inp_s_mask();
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+ inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
+
+ for (int il = 0; il < n_layer; ++il) {
+ const llama_layer * layer = &model.layers[il];
+
+ // (ab)using the KV cache to store the states
+ struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
+ gf, kv_self.k_l[il], state_copy, state_mask,
+ hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
+ struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
+ gf, kv_self.v_l[il], state_copy, state_mask,
+ hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
+
+ cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+ token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs);
+
+ struct ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
+ struct ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
+
+ struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, il);
+ struct ggml_tensor * x_prev = ggml_concat(
+ ctx0,
+ att_shift,
+ ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
+ 1
+ );
+
+ cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
+ ggml_build_forward_expand(gf, cur);
+ ggml_build_forward_expand(
+ gf,
+ ggml_cpy(
+ ctx0,
+ wkv_states,
+ ggml_view_1d(
+ ctx0,
+ kv_self.v_l[il],
+ hparams.n_embd_v_s() * n_seqs,
+ hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+ )
+ )
+ );
+
+ struct ggml_tensor * x_norm_ffn = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, il);
+ x_prev = ggml_concat(
+ ctx0,
+ ffn_shift,
+ ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
+ 1
+ );
+ cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev));
+ ggml_build_forward_expand(gf, cur);
+
+ struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
+ struct ggml_tensor * last_norm_ffn = ggml_view_3d(ctx0, x_norm_ffn, n_embd, 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_ffn));
+
+ token_shift = ggml_concat(ctx0, last_norm_att, last_norm_ffn, 1);
+
+ ggml_build_forward_expand(
+ gf,
+ ggml_cpy(
+ ctx0,
+ ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * 2, 0),
+ ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
+ )
+ );
+
+ if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
+ cur = ggml_scale(ctx0, cur, 0.5F);
+ }
+
+ cur = lctx.cvec.apply_to(ctx0, cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+
+ cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
+ cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+ cb(cur, "result_output", -1);
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
};
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
{
result = llm.build_exaone();
} break;
+ case LLM_ARCH_RWKV6:
+ {
+ result = llm.build_rwkv6();
+ } break;
default:
GGML_ABORT("fatal error");
}
// NOTE: can't use LLM_TN here because the layer number is not known
quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
+ // do not quantize RWKV's time_mix_first tensors
+ quantize &= name.find("time_mix_first.weight") == std::string::npos;
+ quantize &= name.find("time_mix_w1.weight") == std::string::npos;
+ quantize &= name.find("time_mix_w2.weight") == std::string::npos;
+
// do not quantize relative position bias (T5)
quantize &= name.find("attn_rel_b.weight") == std::string::npos;
case LLM_ARCH_T5:
case LLM_ARCH_T5ENCODER:
case LLM_ARCH_JAIS:
+ case LLM_ARCH_RWKV6:
return LLAMA_ROPE_TYPE_NONE;
// use what we call a normal RoPE, operating on pairs of consecutive head values
bool llama_model_is_recurrent(const struct llama_model * model) {
switch (model->arch) {
case LLM_ARCH_MAMBA: return true;
+ case LLM_ARCH_RWKV6: return true;
default: return false;
}
}