MODEL_17M,
MODEL_22M,
MODEL_33M,
+ MODEL_60M,
MODEL_70M,
+ MODEL_80M,
MODEL_109M,
MODEL_137M,
MODEL_160M,
+ MODEL_220M,
+ MODEL_250M,
MODEL_335M,
MODEL_410M,
+ MODEL_770M,
+ MODEL_780M,
MODEL_0_5B,
MODEL_1B,
MODEL_1_3B,
MODEL_6_9B,
MODEL_7B,
MODEL_8B,
+ MODEL_11B,
MODEL_12B,
MODEL_13B,
MODEL_14B,
uint32_t n_expert = 0;
uint32_t n_expert_used = 0;
uint32_t n_vocab_type = 0; // for BERT-style token types
+ uint32_t n_rel_attn_bkts = 0; // number of relative position buckets (T5)
uint32_t n_layer_dense_lead = 0;
uint32_t n_lora_q = 0;
bool use_alibi = false;
bool attn_soft_cap = false;
+ // needed by encoder-decoder models (e.g. T5, FLAN-T5)
+ // ref: https://github.com/ggerganov/llama.cpp/pull/8141
+ llama_token dec_start_token_id = -1;
+
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
if (this->n_expert != other.n_expert) return true;
if (this->n_expert_used != other.n_expert_used) return true;
+ if (this->n_rel_attn_bkts != other.n_rel_attn_bkts) return true;
if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true;
if (this->n_lora_q != other.n_lora_q) return true;
if (this->n_lora_kv != other.n_lora_kv) return true;
if (this->ssm_d_state != other.ssm_d_state) return true;
if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
+ if (this->dec_start_token_id != other.dec_start_token_id) return true;
+
const float EPSILON = 1e-9f;
if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
void * cb_eval_user_data;
};
+// TODO: separate into "llama_layer_enc" and "llama_layer_dec"
struct llama_layer {
// normalization
struct ggml_tensor * attn_norm;
struct ggml_tensor * attn_sub_norm;
struct ggml_tensor * attn_post_norm;
struct ggml_tensor * ffn_sub_norm;
+ struct ggml_tensor * attn_norm_cross;
+ struct ggml_tensor * attn_norm_enc;
// attention
struct ggml_tensor * wq;
struct ggml_tensor * wq_b;
struct ggml_tensor * wkv_a_mqa;
struct ggml_tensor * wkv_b;
+ struct ggml_tensor * wq_cross;
+ struct ggml_tensor * wk_cross;
+ struct ggml_tensor * wv_cross;
+ struct ggml_tensor * wo_cross;
+ struct ggml_tensor * wq_enc;
+ struct ggml_tensor * wk_enc;
+ struct ggml_tensor * wv_enc;
+ struct ggml_tensor * wo_enc;
// attention bias
struct ggml_tensor * bq;
struct ggml_tensor * bo;
struct ggml_tensor * bqkv;
+ // relative position bias
+ struct ggml_tensor * attn_rel_b;
+ struct ggml_tensor * attn_rel_b_enc;
+ struct ggml_tensor * attn_rel_b_cross;
+
// normalization
struct ggml_tensor * ffn_norm;
struct ggml_tensor * ffn_norm_b;
struct ggml_tensor * layer_out_norm;
struct ggml_tensor * layer_out_norm_b;
struct ggml_tensor * ffn_norm_exps;
+ struct ggml_tensor * ffn_norm_enc;
// ff
struct ggml_tensor * ffn_gate; // w1
struct ggml_tensor * ffn_down; // w2
struct ggml_tensor * ffn_up; // w3
+ struct ggml_tensor * ffn_gate_enc;
+ struct ggml_tensor * ffn_down_enc;
+ struct ggml_tensor * ffn_up_enc;
// ff MoE
struct ggml_tensor * ffn_gate_inp;
struct ggml_tensor * output_norm_b;
struct ggml_tensor * output;
struct ggml_tensor * output_b;
+ struct ggml_tensor * output_norm_enc;
std::vector<llama_layer> layers;
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
std::map<llama_seq_id, std::vector<float>> embd_seq;
+ // whether we are computing encoder output or decoder output
+ bool is_encoding = false;
+
+ // output of the encoder part of the encoder-decoder models
+ std::vector<float> embd_enc;
+ std::vector<std::set<llama_seq_id>> seq_ids_enc; // sequence ids of each encoder output, used to build the cross-attention mask
+
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
ggml_backend_sched_t sched = nullptr;
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
+ struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
+ struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
+ struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
// control vectors
struct llama_control_vector cvec;
case MODEL_17M: return "17M";
case MODEL_22M: return "22M";
case MODEL_33M: return "33M";
+ case MODEL_60M: return "60M";
case MODEL_70M: return "70M";
+ case MODEL_80M: return "80M";
case MODEL_109M: return "109M";
case MODEL_137M: return "137M";
case MODEL_160M: return "160M";
+ case MODEL_220M: return "220M";
+ case MODEL_250M: return "250M";
case MODEL_335M: return "335M";
case MODEL_410M: return "410M";
+ case MODEL_770M: return "770M";
+ case MODEL_780M: return "780M";
case MODEL_0_5B: return "0.5B";
case MODEL_1B: return "1B";
case MODEL_1_3B: return "1.3B";
case MODEL_6_9B: return "6.9B";
case MODEL_7B: return "7B";
case MODEL_8B: return "8B";
+ case MODEL_11B: return "11B";
case MODEL_12B: return "12B";
case MODEL_13B: return "13B";
case MODEL_14B: return "14B";
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_T5:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
+
+ uint32_t dec_start_token_id;
+ if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) {
+ hparams.dec_start_token_id = dec_start_token_id;
+ }
+
+ switch (hparams.n_layer) {
+ case 6: model.type = e_model::MODEL_60M; break; // t5-small
+ case 8: model.type = e_model::MODEL_80M; break; // flan-t5-small
+ case 12:
+ switch (hparams.n_ff) {
+ case 3072: model.type = e_model::MODEL_220M; break; // t5-base
+ case 2048: model.type = e_model::MODEL_250M; break; // flan-t5-base
+ default: model.type = e_model::MODEL_UNKNOWN;
+ } break;
+ case 24:
+ switch (hparams.n_ff) {
+ case 4096: model.type = e_model::MODEL_770M; break; // t5-large
+ case 2816: model.type = e_model::MODEL_780M; break; // flan-t5-large
+ case 16384: model.type = e_model::MODEL_3B; break; // t5-3b
+ case 5120: model.type = e_model::MODEL_3B; break; // flan-t5-xl
+ case 65536: model.type = e_model::MODEL_11B; break; // t5-11b
+ case 10240: model.type = e_model::MODEL_11B; break; // flan-t5-xxl
+ default: model.type = e_model::MODEL_UNKNOWN;
+ } break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
case LLM_ARCH_JAIS:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1});
}
} break;
+ case LLM_ARCH_T5:
+ {
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+ // output
+ {
+ model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd});
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd});
+
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ // if output is NULL, init from the input tok embed
+ if (model.output == NULL) {
+ model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+ }
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ auto & layer = model.layers[i];
+
+ layer.attn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd});
+ layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {hparams.n_head, hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+ layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa});
+ layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
+ layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
+ layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
+
+ layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd});
+ layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd});
+ layer.ffn_up_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff});
+
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd});
+ layer.attn_rel_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {hparams.n_head, hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa});
+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
+
+ layer.attn_norm_cross = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd});
+ // this tensor seems to be unused in the HF transformers implementation
+ layer.attn_rel_b_cross = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {hparams.n_head, hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+ layer.wq_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa});
+ layer.wk_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
+ layer.wv_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
+ layer.wo_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
+
+ layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd});
+ layer.ffn_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd});
+ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff});
+ }
+ } break;
case LLM_ARCH_JAIS:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
const int32_t n_tokens;
const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size)
const int32_t n_outputs;
+ const int32_t n_outputs_enc;
const int32_t kv_head; // index of where we store new KV data in the cache
const int32_t n_ctx_orig;
n_tokens (batch.n_tokens),
n_kv (worst_case ? kv_self.size : kv_self.n),
n_outputs (worst_case ? n_tokens : lctx.n_outputs),
+ n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
n_ctx_orig (cparams.n_ctx_orig_yarn),
flash_attn (cparams.flash_attn),
lctx.inp_s_copy = nullptr;
lctx.inp_s_mask = nullptr;
lctx.inp_s_seq = nullptr;
+ lctx.inp_pos_bucket = nullptr;
+ lctx.inp_embd_enc = nullptr;
+ lctx.inp_KQ_mask_cross = nullptr;
}
void free() {
return gf;
}
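+ // relative position buckets used for the T5 attention bias:
+ // the decoder needs one bucket per (kv position, token) pair, the encoder one per (token, token) pair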
+ struct ggml_tensor * llm_build_pos_bucket(bool causal) {
+ if (causal) {
+ lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
+ } else {
+ lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
+ }
+
+ ggml_set_input(lctx.inp_pos_bucket);
+ cb(lctx.inp_pos_bucket, "pos_bucket", -1);
+
+ return lctx.inp_pos_bucket;
+ }
+
+ struct ggml_tensor * llm_build_pos_bias(struct ggml_tensor * pos_bucket, struct ggml_tensor * attn_rel_b) {
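+ // gather the learned bias for each position bucket and rearrange it so it can be added directly to kq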
+ struct ggml_tensor * pos_bucket_1d = ggml_view_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1], 0);
+ cb(pos_bucket_1d, "pos_bucket_1d", -1);
+
+ struct ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
+ cb(pos_bias, "pos_bias", -1);
+
+ pos_bias = ggml_view_3d(ctx0, pos_bias,
+ pos_bias->ne[0], lctx.inp_pos_bucket->ne[0], lctx.inp_pos_bucket->ne[1],
+ ggml_element_size(pos_bias) * pos_bias->ne[0],
+ ggml_element_size(pos_bias) * pos_bias->ne[0] * lctx.inp_pos_bucket->ne[0],
+ 0);
+ cb(pos_bias, "pos_bias", -1);
+
+ pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3);
+ cb(pos_bias, "pos_bias", -1);
+
+ pos_bias = ggml_cont(ctx0, pos_bias);
+ cb(pos_bias, "pos_bias", -1);
+
+ return pos_bias;
+ }
+
+ struct ggml_tensor * llm_build_inp_embd_enc() {
+ const int64_t n_embd = hparams.n_embd;
+ lctx.inp_embd_enc = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc);
+ ggml_set_input(lctx.inp_embd_enc);
+ cb(lctx.inp_embd_enc, "embd_enc", -1);
+ return lctx.inp_embd_enc;
+ }
+
+ struct ggml_tensor * llm_build_inp_KQ_mask_cross() {
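+ // one row per decoder token (padded to GGML_KQ_MASK_PAD), one column per encoder output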
+ lctx.inp_KQ_mask_cross = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+ ggml_set_input(lctx.inp_KQ_mask_cross);
+ cb(lctx.inp_KQ_mask_cross, "KQ_mask_cross", -1);
+ return lctx.inp_KQ_mask_cross;
+ }
+
struct ggml_cgraph * build_llama() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
return gf;
}
+ struct ggml_cgraph * build_t5() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ // mutable variable, needed during the last layer of the computation to skip unused tokens
+ int32_t n_tokens = this->n_tokens;
+
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+ if (lctx.is_encoding) {
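+ // encoder: bidirectional self-attention with a relative position bias, computed directly on the batch (no KV cache)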
+ struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
+
+ // KQ_mask (mask for 1 head, it will be broadcast to all heads)
+ struct ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false);
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm_enc, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq_enc, cur);
+ cb(Qcur, "Qcur", il);
+
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk_enc, cur);
+ cb(Kcur, "Kcur", il);
+
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv_enc, cur);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+ struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+ struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+
+ struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+ cb(kq, "kq", il);
+
+ struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
+ struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_enc, attn_rel_b);
+ struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
+ cb(kq_b, "kq_b", il);
+
+ kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
+
+ struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
+ cb(v, "v", il);
+
+ struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
+ cb(kqv, "kqv", il);
+
+ struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+ cb(kqv_merged, "kqv_merged", il);
+
+ cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+ cb(cur, "kqv_merged_cont", il);
+
+ ggml_build_forward_expand(gf, cur);
+
+ cur = ggml_mul_mat(ctx0, model.layers[il].wo_enc, cur);
+ cb(cur, "kqv_out", il);
+ }
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ n_tokens = n_outputs;
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward network
+ {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm_enc, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ // T5 uses ReLU, Flan-T5 uses a gated GELU
+ cur = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up_enc, NULL, NULL,
+ model.layers[il].ffn_gate_enc, NULL, NULL,
+ model.layers[il].ffn_down_enc, NULL, NULL,
+ NULL,
+ model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
+ model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
+ cb, il);
+ cb(cur, "ffn_out", il);
+ }
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "ffn_out", il);
+
+ ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+ if (layer_dir != nullptr) {
+ cur = ggml_add(ctx0, cur, layer_dir);
+ }
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+ cb(cur, "result_embd", -1);
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm_enc, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+ } else {
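+ // decoder: causal self-attention over the KV cache followed by cross-attention over the encoder output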
+ struct ggml_tensor * embd_enc = llm_build_inp_embd_enc();
+ struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
+
+ struct ggml_tensor * KQ_mask_dec = build_inp_KQ_mask();
+ struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross();
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
+ llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+
+ struct ggml_tensor * k =
+ ggml_view_3d(ctx0, kv_self.k_l[il],
+ n_embd_head_k, n_kv, n_head_kv,
+ ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+ ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+ 0);
+ cb(k, "k", il);
+
+ struct ggml_tensor * v =
+ ggml_view_3d(ctx0, kv_self.v_l[il],
+ n_kv, n_embd_head_v, n_head_kv,
+ ggml_element_size(kv_self.v_l[il])*n_ctx,
+ ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+ 0);
+ cb(v, "v", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+ struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+
+ struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+ cb(kq, "kq", il);
+
+ struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
+ struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b);
+ struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
+ cb(kq_b, "kq_b", il);
+
+ kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
+
+ struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+ cb(kqv, "kqv", il);
+
+ struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+ cb(kqv_merged, "kqv_merged", il);
+
+ cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+ cb(cur, "kqv_merged_cont", il);
+
+ ggml_build_forward_expand(gf, cur);
+
+ cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+ cb(cur, "kqv_out", il);
+ }
+
+ cur = ggml_add(ctx0, cur, inpSA);
+ cb(cur, "cross_inp", il);
+
+ struct ggml_tensor * inpCA = cur;
+
+ // norm
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.layers[il].attn_norm_cross, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm_cross", il);
+
+ // cross-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq_cross, cur);
+ cb(Qcur, "Qcur", il);
+
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk_cross, embd_enc);
+ cb(Kcur, "Kcur", il);
+
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv_cross, embd_enc);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
+
+ struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+ struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+
+ struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+ cb(kq, "kq", il);
+
+ kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
+
+ struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc)));
+ cb(v, "v", il);
+
+ struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq);
+ cb(kqv, "kqv", il);
+
+ struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+ cb(kqv_merged, "kqv_merged", il);
+
+ cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+ cb(cur, "kqv_merged_cont", il);
+
+ ggml_build_forward_expand(gf, cur);
+
+ cur = ggml_mul_mat(ctx0, model.layers[il].wo_cross, cur);
+ cb(cur, "kqv_out", il);
+ }
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ n_tokens = n_outputs;
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
+ }
+
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward network
+ {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ // T5 uses ReLU, Flan-T5 uses a gated GELU
+ cur = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up, NULL, NULL,
+ model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL,
+ NULL,
+ model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
+ model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
+ cb, il);
+ cb(cur, "ffn_out", il);
+ }
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "ffn_out", il);
+
+ ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+ if (layer_dir != nullptr) {
+ cur = ggml_add(ctx0, cur, layer_dir);
+ }
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+ cb(cur, "result_embd", -1);
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
+ }
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
+
struct ggml_cgraph * build_jais() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
{
result = llm.build_bitnet();
} break;
+ case LLM_ARCH_T5:
+ {
+ result = llm.build_t5();
+ } break;
case LLM_ARCH_JAIS:
{
result = llm.build_jais();
}
}
+static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+ // TODO move to hparams if a T5 variant appears that uses a different value
+ const int64_t max_distance = 128;
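+ // T5-style bucketing: small relative distances get their own bucket, larger distances share
+ // logarithmically spaced buckets up to max_distance; bidirectional (encoder) attention additionally
+ // splits the buckets between positive and negative offsets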
+
+ if (bidirectional) {
+ n_buckets >>= 1;
+ }
+
+ const int64_t max_exact = n_buckets >> 1;
+
+ int32_t relative_position = x - y;
+ int32_t relative_bucket = 0;
+ if (bidirectional) {
+ relative_bucket += (relative_position > 0) * n_buckets;
+ relative_position = abs(relative_position);
+ } else {
+ relative_position = -std::min<int32_t>(relative_position, 0);
+ }
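+ // distances >= max_exact are mapped logarithmically into the remaining buckets, saturating at the last bucket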
+ int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+ relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+ relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+ return relative_bucket;
+}
+
static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
//
// set input data
if (lctx.inp_KQ_mask) {
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
- if (cparams.causal_attn) {
+ if (cparams.causal_attn && !lctx.is_encoding) {
const int64_t n_kv = kv_self.n;
const int64_t n_tokens = batch.n_tokens;
} else {
// when using kv cache, the mask needs to match the kv cache size
const int64_t n_tokens = batch.n_tokens;
- const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
+ const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
}
}
}
+
+ if (lctx.inp_pos_bucket) {
+ const int64_t n_tokens = batch.n_tokens;
+
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
+
+ int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
+
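+ // decoder: bucket the distance from each KV cache position to each new token; encoder: bucket distances within the batch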
+ if (!lctx.is_encoding) {
+ const int64_t n_kv = kv_self.n;
+ for (int h = 0; h < 1; ++h) {
+ for (int j = 0; j < n_tokens; ++j) {
+ for (int i = 0; i < n_kv; ++i) {
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
+ }
+ }
+ }
+ } else {
+ for (int h = 0; h < 1; ++h) {
+ for (int j = 0; j < n_tokens; ++j) {
+ for (int i = 0; i < n_tokens; ++i) {
+ data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(batch.pos[i], batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
+ }
+ }
+ }
+ }
+ }
+
+ if (!lctx.is_encoding && lctx.inp_embd_enc) {
+ assert(lctx.inp_embd_enc->type == GGML_TYPE_F32);
+ assert((size_t) ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size());
+
+ ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, ggml_nbytes(lctx.inp_embd_enc));
+ }
+
+ if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) {
+ const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd;
+ const int64_t n_tokens = batch.n_tokens;
+
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
+
+ float * data = (float *) lctx.inp_KQ_mask_cross->data;
+
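+ // a decoder token may attend to an encoder output only if the two share at least one sequence id; padded rows are fully masked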
+ for (int h = 0; h < 1; ++h) {
+ for (int j = 0; j < n_tokens; ++j) {
+ for (int i = 0; i < n_output_enc; ++i) {
+ float f = -INFINITY;
+ for (int s = 0; s < batch.n_seq_id[j]; ++s) {
+ const llama_seq_id seq_id = batch.seq_id[j][s];
+ if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
+ f = 0.0f;
+ }
+ }
+ data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
+ }
+ }
+
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+ for (int j = 0; j < n_output_enc; ++j) {
+ data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
+ }
+ }
+ }
+ }
}
// Make sure enough space is available for outputs.
// TODO: use a per-batch flag for logits presence instead
const bool has_logits = !cparams.embeddings;
- const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+ const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
llama_context & lctx,
llama_batch batch_all) { // TODO: rename back to batch
+ lctx.is_encoding = false;
const uint32_t n_tokens_all = batch_all.n_tokens;
if (n_tokens_all == 0) {
const auto n_ubatch = cparams.n_ubatch;
+ // TODO: simplify or deprecate
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id_arr;
return 0;
}
+// encode a batch of tokens by evaluating the encoder part of the transformer
+//
+// - lctx: llama context
+// - batch: batch to evaluate
+//
+// return 0 on success
+// return positive int on warning
+// return negative int on error
+//
+static int llama_encode_internal(
+ llama_context & lctx,
+ llama_batch batch) {
+
+ lctx.is_encoding = true;
+
+ const uint32_t n_tokens = batch.n_tokens;
+
+ if (n_tokens == 0) {
+ LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+ return -1;
+ }
+
+ const auto & model = lctx.model;
+ const auto & hparams = model.hparams;
+ const auto & cparams = lctx.cparams;
+
+ GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+ // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
+ GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
+
+ if (lctx.t_compute_start_us == 0) {
+ lctx.t_compute_start_us = ggml_time_us();
+ }
+
+ lctx.n_queued_tokens += n_tokens;
+
+ const int64_t n_embd = hparams.n_embd;
+
+ // TODO: simplify or deprecate
+ std::vector<llama_pos> pos;
+ std::vector<int32_t> n_seq_id;
+ std::vector<llama_seq_id *> seq_id_arr;
+ std::vector<std::vector<llama_seq_id>> seq_id;
+
+ // reserve output buffer
+ if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
+ return -2;
+ }
+
+ for (uint32_t i = 0; i < n_tokens; ++i) {
+ lctx.output_ids[i] = i;
+ }
+
+ lctx.inp_embd_enc = NULL;
+ lctx.n_outputs = n_tokens;
+
+ const int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+ GGML_ASSERT(n_threads > 0);
+
+ // helpers for smoother batch API transition
+ // after deprecating the llama_eval calls, these will be removed
+ if (batch.pos == nullptr) {
+ pos.resize(n_tokens);
+ for (uint32_t i = 0; i < n_tokens; i++) {
+ pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
+ }
+
+ batch.pos = pos.data();
+ }
+
+ if (batch.seq_id == nullptr) {
+ n_seq_id.resize(n_tokens);
+ seq_id.resize(n_tokens);
+ seq_id_arr.resize(n_tokens);
+ for (uint32_t i = 0; i < n_tokens; i++) {
+ n_seq_id[i] = 1;
+ seq_id[i].resize(1);
+ seq_id[i][0] = batch.all_seq_id;
+ seq_id_arr[i] = seq_id[i].data();
+ }
+
+ batch.n_seq_id = n_seq_id.data();
+ batch.seq_id = seq_id_arr.data();
+ }
+
+ ggml_backend_sched_reset(lctx.sched);
+ ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
+
+ ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
+
+ // the output embeddings after the final encoder normalization
+ struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 1];
+
+ GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
+
+ ggml_backend_sched_alloc_graph(lctx.sched, gf);
+
+ llama_set_inputs(lctx, batch);
+
+ llama_graph_compute(lctx, gf, n_threads);
+
+ // extract embeddings
+ if (embd) {
+ ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+ GGML_ASSERT(backend_embd != nullptr);
+
+ // extract token embeddings
+ GGML_ASSERT(lctx.embd != nullptr);
+
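+ // embd_enc stores the encoder output as n_tokens rows of n_embd floats; it is copied into inp_embd_enc when decoding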
+ lctx.embd_enc.resize(n_tokens*n_embd);
+ float * embd_out = lctx.embd_enc.data();
+
+ ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+
+ // remember the sequence ids used during the encoding - needed for cross attention later
+ lctx.seq_ids_enc.resize(n_tokens);
+ for (uint32_t i = 0; i < n_tokens; i++) {
+ for (int s = 0; s < batch.n_seq_id[i]; s++) {
+ llama_seq_id seq_id = batch.seq_id[i][s];
+ lctx.seq_ids_enc[i].insert(seq_id);
+ }
+ }
+ }
+
+ // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+ // overlap with device computation.
+ ggml_backend_sched_reset(lctx.sched);
+
+ return 0;
+}
// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
std::string normalized;
normalize(text, &normalized);
size_t input_len = normalized.size();
+ if (input_len == 0) {
+ return;
+ }
// initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
- std::vector<struct best_tokenization> tokenization_results(input_len + 1, {0, 0, -FLT_MAX});
+ std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.special_unk_id, 0, -FLT_MAX});
// at the beginning tokenization score is zero
- tokenization_results[0] = { 0, 0, 0 };
+ tokenization_results[0] = { vocab.special_unk_id, 0, 0 };
for (size_t input_offset = 0; input_offset < input_len;) {
size_t prefix_offset = input_offset;
single_codepoint_token_found = true;
}
llama_token token_id = node->value;
- const auto &token_data = vocab.id_to_token[token_id];
+ const auto & token_data = vocab.id_to_token[token_id];
// we set the user-defined token scores to 0 to make them more likely to be selected
// (normal token scores are log probabilities, so they are negative)
// sanity checks
//
- // - qs.n_attention_wv == 0 for Mamba models
- // - qs.n_attention_wv == model.hparams.n_layer for Transformer models
+ // - qs.n_attention_wv == 0 for Mamba models
+ // - qs.n_attention_wv == model.hparams.n_layer for Transformer models
+ // - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models
//
- GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer) && "n_attention_wv is unexpected");
+ GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer) && "n_attention_wv is unexpected");
size_t total_size_org = 0;
size_t total_size_new = 0;
quantize &= name.find("ssm_x.weight") == std::string::npos;
quantize &= name.find("ssm_dt.weight") == std::string::npos;
+ // do not quantize relative position bias (T5)
+ quantize &= name.find("attn_rel_b.weight") == std::string::npos;
+
enum ggml_type new_type;
void * new_data;
size_t new_size;
return it->second;
}
+bool llama_model_has_encoder(const struct llama_model * model) {
+ switch (model->arch) {
+ case LLM_ARCH_T5: return true;
+ default: return false;
+ }
+}
+
+llama_token llama_model_decoder_start_token(const struct llama_model * model) {
+ return model->hparams.dec_start_token_id;
+}
+
uint32_t llama_model_quantize(
const char * fname_inp,
const char * fname_out,
if (batch.logits) free(batch.logits);
}
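+// For encoder-decoder models (see llama_model_has_encoder) the expected flow is roughly:
+//
+//   llama_encode(ctx, prompt_batch);                           // run the encoder once over the input
+//   llama_token tok = llama_model_decoder_start_token(model);  // seed the decoder
+//   ... llama_decode() the decoder tokens autoregressively ...
+//
+// (a sketch of the intended usage; see the referenced PR for a complete example)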
+int32_t llama_encode(
+ struct llama_context * ctx,
+ struct llama_batch batch) {
+ const int ret = llama_encode_internal(*ctx, batch);
+ if (ret < 0) {
+ LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
+ }
+
+ return ret;
+}
+
int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch) {