LLM_ARCH_STARCODER,
LLM_ARCH_PERSIMMON,
LLM_ARCH_REFACT,
+ LLM_ARCH_BERT,
LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM,
LLM_ARCH_QWEN,
{ LLM_ARCH_STARCODER, "starcoder" },
{ LLM_ARCH_PERSIMMON, "persimmon" },
{ LLM_ARCH_REFACT, "refact" },
+ { LLM_ARCH_BERT, "bert" },
{ LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" },
{ LLM_ARCH_QWEN, "qwen" },
LLM_KV_ATTENTION_VALUE_LENGTH,
LLM_KV_ATTENTION_LAYERNORM_EPS,
LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
+ LLM_KV_ATTENTION_CAUSAL,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_FREQ_BASE,
LLM_KV_TOKENIZER_MODEL,
LLM_KV_TOKENIZER_LIST,
LLM_KV_TOKENIZER_TOKEN_TYPE,
+ LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,
LLM_KV_TOKENIZER_SCORES,
LLM_KV_TOKENIZER_MERGES,
LLM_KV_TOKENIZER_BOS_ID,
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
+ { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
{ LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" },
+ { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" },
{ LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" },
{ LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" },
{ LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" },
enum llm_tensor {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_TOKEN_EMBD_NORM,
+ LLM_TENSOR_TOKEN_TYPES,
LLM_TENSOR_POS_EMBD,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_OUTPUT_NORM,
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_BERT,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+ { LLM_TENSOR_TOKEN_TYPES, "token_types" },
+ { LLM_TENSOR_POS_EMBD, "position_embd" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_output_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.layer_output_norm" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ },
+ },
{
LLM_ARCH_BLOOM,
{
// available llama models
enum e_model {
MODEL_UNKNOWN,
+ MODEL_17M,
+ MODEL_22M,
+ MODEL_33M,
+ MODEL_109M,
+ MODEL_335M,
MODEL_0_5B,
MODEL_1B,
MODEL_2B,
uint32_t n_ff;
uint32_t n_expert = 0;
uint32_t n_expert_used = 0;
+ uint32_t n_vocab_type = 0; // for BERT-style token types
float f_norm_eps;
float f_norm_rms_eps;
float f_clamp_kqv;
float f_max_alibi_bias;
+ bool causal_attn = true; // encoder-only models such as BERT set this to false
+
bool operator!=(const llama_hparams & other) const {
if (this->vocab_only != other.vocab_only) return true;
llama_vocab vocab;
struct ggml_tensor * tok_embd;
+ struct ggml_tensor * type_embd;
struct ggml_tensor * pos_embd;
struct ggml_tensor * tok_norm;
struct ggml_tensor * tok_norm_b;
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
ggml_backend_sched_t sched = nullptr;
- // allocator for the input tensors
- ggml_tallocr * alloc = nullptr;
// input tensors
ggml_backend_buffer_t buf_input = nullptr;
struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
+ struct ggml_tensor * inp_sum; // F32 [1, n_batch] pooling weights (1/n_tokens for mean pooling)
#ifdef GGML_USE_MPI
ggml_mpi_context * ctx_mpi = NULL;
switch (type) {
case LLAMA_VOCAB_TYPE_SPM: return "SPM";
case LLAMA_VOCAB_TYPE_BPE: return "BPE";
+ case LLAMA_VOCAB_TYPE_WPM: return "WPM";
default: return "unknown";
}
}
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_BERT:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+ ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
+ ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
+
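+ // infer the model size from the layer count (and from n_embd for 12-layer models)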
+ switch (hparams.n_layer) {
+ case 3:
+ model.type = e_model::MODEL_17M; break; // bge-micro
+ case 6:
+ model.type = e_model::MODEL_22M; break; // MiniLM-L6
+ case 12:
+ switch (hparams.n_embd) {
+ case 384: model.type = e_model::MODEL_33M; break; // MiniLM-L12, bge-small
+ case 768: model.type = e_model::MODEL_109M; break; // bge-base
+ } break;
+ case 24:
+ model.type = e_model::MODEL_335M; break; // bge-large
+ }
+ } break;
case LLM_ARCH_BLOOM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
vocab.special_unk_id = -1;
vocab.special_sep_id = -1;
vocab.special_pad_id = -1;
+ } else if (tokenizer_name == "bert") {
+ vocab.type = LLAMA_VOCAB_TYPE_WPM;
+
+ // default special tokens
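+ // 100/101/102 are [UNK]/[CLS]/[SEP] in the standard BERT WordPiece vocab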
+ vocab.special_bos_id = 101;
+ vocab.special_eos_id = 102;
+ vocab.special_unk_id = 100;
+ vocab.special_sep_id = -1;
+ vocab.special_pad_id = -1;
+ vocab.add_space_prefix = false;
} else {
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
// determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
+ } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
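+ // WordPiece vocabs have no dedicated newline token, so the pad id is used as a stand-in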
+ vocab.linefeed_id = vocab.special_pad_id;
} else {
const std::vector<int> ids = llama_tokenize_internal(vocab, "\u010A", false);
GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const int64_t n_embd_gqa = n_embd_v_gqa;
const int64_t n_vocab = hparams.n_vocab;
+ const int64_t n_vocab_type = hparams.n_vocab_type;
const int64_t n_ff = hparams.n_ff;
GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64});
}
} break;
- case LLM_ARCH_BLOOM:
+ case LLM_ARCH_BERT:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
- model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
- model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd});
+ model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type});
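+ // n_vocab_type is typically 2 for BERT (token types for sentence A and sentence B)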
+ model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train});
+ model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
+ model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd});
+
+ for (int i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+ layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
+
+ layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+ layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
+
+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+ layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd});
+
+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+ layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa});
+
+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+ layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa});
+
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+ layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
+
+ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+ layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
+
+ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+ layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
+ }
+ } break;
+ case LLM_ARCH_BLOOM:
+ {
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+ model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
+ model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd});
// output
{
const int32_t n_orig_ctx;
const bool do_rope_shift;
+ const bool causal_attn;
const llm_build_cb & cb;
kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
do_rope_shift (worst_case || kv_self.has_shift),
+ causal_attn (hparams.causal_attn),
cb (cb),
buf_compute_meta (lctx.buf_compute_meta) {
// all initializations should be done in init()
return gf;
}
+ struct ggml_cgraph * build_bert() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ // input views with the right size for this batch
+ struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
+ struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0);
+
+ // construct input embeddings (token, type, position)
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
+ // token types are hardcoded to zero ("Sentence A")
+ struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
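+ // ggml_add broadcasts this single row across all tokens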
+ inpL = ggml_add(ctx0, inpL, type_row0);
+ inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
+ cb(inpL, "inp_embd", -1);
+
+ // embed layer norm
+ inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
+ cb(inpL, "inp_norm", -1);
+
+ // KQ_mask (mask for 1 head; it will be broadcast to all heads)
+ struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+ cb(KQ_mask, "KQ_mask", -1); // [n_kv, n_tokens]
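+ // (the mask values themselves are filled in by llama_set_inputs)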
+
+ // iterate layers
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * cur = inpL;
+
+ // self-attention
+ {
+ struct ggml_tensor * Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
+ cb(Qcur, "Qcur", il);
+
+ struct ggml_tensor * Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
+ cb(Kcur, "Kcur", il);
+
+ struct ggml_tensor * Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
+ cb(Vcur, "Vcur", il);
+
+ // only Q needs to be reshaped to 3D here; K and V are handled inside llm_build_kv
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
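+ // store K/V in the cache and compute the attention output (scores scaled by 1/sqrt(n_embd_head))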
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+ model.layers[il].wo, model.layers[il].bo,
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ cb(cur, "kqv_out", il);
+ }
+
+ // residual connection: add back the layer input
+ cur = ggml_add(ctx0, cur, inpL);
+
+ // attention layer norm
+ cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, cb, il);
+
+ struct ggml_tensor * ffn_inp = cur;
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward network (GELU, no gate)
+ cur = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up, model.layers[il].ffn_up_b,
+ NULL, NULL,
+ model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+ NULL,
+ LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+ cb(cur, "ffn_out", il);
+
+ // residual connection around the feed-forward block
+ cur = ggml_add(ctx0, cur, ffn_inp);
+
+ // output layer norm
+ cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ // final output
+ cur = inpL;
+
+ // pooling
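+ // inp_sum holds a uniform weight of 1/n_tokens per token, so this matmul averages the token embeddings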
+ cur = ggml_mul_mat(ctx0, inp_sum, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
+ cb(cur, "result_embed", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
+
struct ggml_cgraph * build_bloom() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
static struct ggml_cgraph * llama_build_graph(
llama_context & lctx,
- const llama_batch & batch) {
+ const llama_batch & batch,
+ bool worst_case) {
const auto & model = lctx.model;
- // check if we should build the worst-case graph (for memory measurement)
- const bool worst_case = ggml_tallocr_is_measure(lctx.alloc);
-
// this callback allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
if (il >= 0) {
struct llm_build_context llm(lctx, batch, cb, worst_case);
- //
- // set input data
- //
-
- if (!ggml_tallocr_is_measure(lctx.alloc)) {
- if (batch.token) {
- const int64_t n_tokens = batch.n_tokens;
-
- ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
- }
-
- if (batch.embd) {
- const int64_t n_embd = llm.n_embd;
- const int64_t n_tokens = batch.n_tokens;
-
- ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
- }
-
- if (batch.pos) {
- const int64_t n_tokens = batch.n_tokens;
-
- ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
- }
-
- {
- const int64_t n_kv = llm.n_kv;
- const int64_t n_tokens = batch.n_tokens;
-
- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
- float * data = (float *) lctx.inp_KQ_mask->data;
-
- for (int h = 0; h < 1; ++h) {
- for (int j = 0; j < n_tokens; ++j) {
- const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j][0];
-
- for (int i = 0; i < n_kv; ++i) {
- float f;
- if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
- f = -INFINITY;
- } else {
- f = 0;
- }
- data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
- }
- }
- }
- }
-
- if (llm.do_rope_shift) {
- const int64_t n_ctx = llm.n_ctx;
-
- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
- int32_t * data = (int32_t *) lctx.inp_K_shift->data;
-
- for (int i = 0; i < n_ctx; ++i) {
- data[i] = lctx.kv_self.cells[i].delta;
- }
- }
- }
-
llm.init();
switch (model.arch) {
{
result = llm.build_refact();
} break;
+ case LLM_ARCH_BERT:
+ {
+ result = llm.build_bert();
+ } break;
case LLM_ARCH_BLOOM:
{
result = llm.build_bloom();
return result;
}
+static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
+ //
+ // set input data
+ //
+
+ const auto & hparams = lctx.model.hparams;
+ const auto & cparams = lctx.cparams;
+ const auto & kv_self = lctx.kv_self;
+
+ if (batch.token) {
+ const int64_t n_tokens = batch.n_tokens;
+
+ ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
+ }
+
+ if (batch.embd) {
+ const int64_t n_embd = hparams.n_embd;
+ const int64_t n_tokens = batch.n_tokens;
+
+ ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
+ }
+
+ if (batch.pos) {
+ const int64_t n_tokens = batch.n_tokens;
+
+ ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
+ }
+
+ {
+ const int64_t n_kv = kv_self.n;
+ const int64_t n_tokens = batch.n_tokens;
+
+ assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+
+ float * data = (float *) lctx.inp_KQ_mask->data;
+
+ for (int h = 0; h < 1; ++h) {
+ for (int j = 0; j < n_tokens; ++j) {
+ const llama_pos pos = batch.pos[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
+
+ for (int i = 0; i < n_kv; ++i) {
+ float f;
+ if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+ f = -INFINITY;
+ } else {
+ f = 0;
+ }
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+ }
+ }
+ }
+ }
+
+ {
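+ // uniform pooling weights (1/n_tokens each); used by build_bert for mean pooling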
+ assert(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
+ float * data = (float *) lctx.inp_sum->data;
+
+ for (int i = 0; i < batch.n_tokens; ++i) {
+ data[i] = 1.0f/float(batch.n_tokens);
+ }
+ }
+
+ if (kv_self.has_shift) {
+ const int64_t n_ctx = cparams.n_ctx;
+
+ assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
+
+ int32_t * data = (int32_t *) lctx.inp_K_shift->data;
+
+ for (int i = 0; i < n_ctx; ++i) {
+ data[i] = lctx.kv_self.cells[i].delta;
+ }
+ }
+}
+
// decode a batch of tokens by evaluating the transformer
//
// - lctx: llama context
ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
- ggml_cgraph * gf = llama_build_graph(lctx, batch);
+ ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
// the output is always the last tensor in the graph
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
- GGML_ASSERT(strcmp(res->name, "result_output") == 0);
-
- // the embeddings could be the second to last tensor, or the third to last tensor
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
- if (strcmp(embeddings->name, "result_norm") != 0) {
- embeddings = gf->nodes[gf->n_nodes - 3];
- GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+ if (strcmp(res->name, "result_output") == 0) {
+ // the embeddings could be the second to last tensor, or the third to last tensor
+ if (strcmp(embeddings->name, "result_norm") != 0) {
+ embeddings = gf->nodes[gf->n_nodes - 3];
+ GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+ }
+ } else if (strcmp(res->name, "result_embed") == 0) {
+ embeddings = res;
+ res = nullptr;
+ } else {
+ GGML_ASSERT(false);
}
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
if (lctx.backend_cpu != nullptr) {
ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
}
+
+ llama_set_inputs(lctx, batch);
+
ggml_backend_sched_graph_compute(lctx.sched, gf);
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
// extract logits
// TODO: do not compute and extract logits if only embeddings are needed
// need to update the graphs to skip "result_output"
- {
+ if (res) {
auto & logits_out = lctx.logits;
#ifndef NDEBUG
if (!lctx.embedding.empty()) {
auto & embedding_out = lctx.embedding;
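+ // result_norm graphs: take the last token's hidden state; result_embed graphs: the pooled vector starts at offset 0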
+ const int64_t embed_pos = res ? n_embd * (n_tokens-1) : 0;
+
embedding_out.resize(n_embd);
ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
- ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), (n_embd*(n_tokens - 1))*sizeof(float), n_embd*sizeof(float));
+ ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), n_embd*sizeof(float));
ggml_backend_synchronize(embeddings_backend);
}
GGML_ASSERT(false);
return unicode_to_bytes_bpe(token_data.text);
}
+ case LLAMA_VOCAB_TYPE_WPM: {
+ GGML_ASSERT(false);
+ }
default:
GGML_ASSERT(false);
}
const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
return vocab.token_to_id.at(buf);
}
+ case LLAMA_VOCAB_TYPE_WPM:
case LLAMA_VOCAB_TYPE_BPE: {
return vocab.token_to_id.at(bytes_to_unicode_bpe(ch));
}
llm_bigram_bpe::queue work_queue;
};
-typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{
+struct llm_tokenizer_wpm {
+ llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
+
+ void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+ auto * token_map = &vocab.token_to_id;
+
+ // normalize and split by whitespace
+ std::vector<std::string> words = preprocess(text);
+
+ // the BOS token has already been prepended by the caller
+
+ // greedily match the longest vocab tokens that make up each word
+ for (const std::string &word : words) {
+ // skip empty words
+ if (word.size() == 0) {
+ continue;
+ }
+
+ // prepend a phantom space ("▁"); the vocab is expected to store word-initial pieces with this prefix
+ std::string word1 = "\xe2\x96\x81" + word;
+ int n = word1.size();
+
+ // we're at the start of a new word
+ int i = 0;
+ bool match_any = false;
+
+ // advance through the character positions of the word
+ while (i < n) {
+ // try match lengths from longest to shortest
+ bool match = false;
+ for (int j = n; j > i; j--) {
+ auto it = token_map->find(word1.substr(i, j - i));
+ if (it != token_map->end()) {
+ output.push_back(it->second);
+ match = true;
+ match_any = true;
+ i = j;
+ break;
+ }
+ }
+
+ // no vocab token matches at this position; skip one byte
+ if (!match) {
+ i++;
+ }
+ }
+
+ // we didn't find any matches for this word
+ if (!match_any) {
+ output.push_back(vocab.special_unk_id);
+ }
+ }
+
+ // append eos token
+ output.push_back(vocab.special_eos_id);
+ }
+
+ std::vector<std::string> preprocess(const std::string & text) {
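+ // steps: lowercase and strip accents, pad punctuation and Chinese characters with spaces, then split on whitespace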
+ std::string ori_str = normalize(text);
+ uint64_t ori_size = ori_str.size();
+
+ // surround punctuation and Chinese characters with spaces so that each becomes its own word
+ std::vector<std::string> words;
+ std::string new_str = "";
+ uint64_t i = 0;
+ while (i < ori_size) {
+ int utf_char_len = utf8_len(ori_str[i]);
+ if ((utf_char_len == 1) && ispunct(ori_str[i])) {
+ new_str += " ";
+ new_str += ori_str[i];
+ new_str += " ";
+ i += 1;
+ }
+ else if ((utf_char_len == 3) && is_chinese_char(ori_str.substr(i, 3))) {
+ new_str += " ";
+ new_str += ori_str.substr(i, 3);
+ new_str += " ";
+ i += 3;
+ }
+ else {
+ new_str += ori_str[i];
+ i += 1;
+ }
+ }
+
+ // split by whitespace
+ uint64_t l = 0;
+ uint64_t r = 0;
+ while (r < new_str.size()) {
+ // split on whitespace
+ if (isspace(new_str[r])) {
+ if (r > l) words.push_back(new_str.substr(l, (r - l)));
+ l = r + 1;
+ r = l;
+ }
+ else {
+ r += 1;
+ }
+ }
+ if (r > l) {
+ words.push_back(new_str.substr(l, (r - l)));
+ }
+ return words;
+ }
+
+ std::string normalize(const std::string & text) {
+ // TODO: handle chinese characters? https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98
+ std::string text2 = strip_accents(text);
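+ // ASCII-only lowercasing; non-ASCII uppercase letters are left unchanged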
+ for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) {
+ char c = text2[i];
+ if (c >= 'A' && c <= 'Z') {
+ text2[i] = c - 'A' + 'a';
+ }
+ }
+ return text2;
+ }
+
+ bool is_chinese_char(const std::string & str) {
+ int len = str.length();
+ unsigned int codepoint = 0;
+ int num_bytes = 0;
+ int i = 0;
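+ // decode the first UTF-8 code point of the string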
+ unsigned char ch = static_cast<unsigned char>(str[i]);
+ if (ch <= 0x7f) {
+ codepoint = ch;
+ num_bytes = 1;
+ } else if ((ch >> 5) == 0x06) {
+ codepoint = ch & 0x1f;
+ num_bytes = 2;
+ } else if ((ch >> 4) == 0x0e) {
+ codepoint = ch & 0x0f;
+ num_bytes = 3;
+ } else if ((ch >> 3) == 0x1e) {
+ codepoint = ch & 0x07;
+ num_bytes = 4;
+ }
+ for (int j = 1; j < num_bytes; ++j) {
+ if (i + j >= len) {
+ return false; // incomplete UTF-8 character
+ }
+ unsigned char next_ch = static_cast<unsigned char>(str[i + j]);
+ if ((next_ch >> 6) != 0x02) {
+ return false; // invalid trailing byte
+ }
+ codepoint = (codepoint << 6) | (next_ch & 0x3f);
+ }
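+ // CJK Unified Ideographs (and extensions), compatibility ideographs, CJK punctuation, and fullwidth forms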
+ if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) ||
+ (codepoint >= 0x3400 && codepoint <= 0x4DBF) ||
+ (codepoint >= 0x20000 && codepoint <= 0x2A6DF) ||
+ (codepoint >= 0x2A700 && codepoint <= 0x2B73F) ||
+ (codepoint >= 0x2B740 && codepoint <= 0x2B81F) ||
+ (codepoint >= 0x2B920 && codepoint <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
+ (codepoint >= 0xF900 && codepoint <= 0xFAFF) ||
+ (codepoint >= 0x2F800 && codepoint <= 0x2FA1F) ||
+ (codepoint >= 0x3000 && codepoint <= 0x303F) ||
+ (codepoint >= 0xFF00 && codepoint <= 0xFFEF)) {
+ return true; // NOLINT
+ }
+ return false;
+ }
+
+ std::string strip_accents(const std::string & input_string) {
+ std::string resultString;
+ std::map<std::string, char> accent_map = {
+ {"À", 'A'}, {"Á", 'A'}, {"Â", 'A'}, {"Ã", 'A'}, {"Ä", 'A'}, {"Å", 'A'},
+ {"à", 'a'}, {"á", 'a'}, {"â", 'a'}, {"ã", 'a'}, {"ä", 'a'}, {"å", 'a'},
+ {"È", 'E'}, {"É", 'E'}, {"Ê", 'E'}, {"Ë", 'E'}, {"è", 'e'}, {"é", 'e'},
+ {"ê", 'e'}, {"ë", 'e'}, {"Ì", 'I'}, {"Í", 'I'}, {"Î", 'I'}, {"Ï", 'I'},
+ {"ì", 'i'}, {"í", 'i'}, {"î", 'i'}, {"ï", 'i'}, {"Ò", 'O'}, {"Ó", 'O'},
+ {"Ô", 'O'}, {"Õ", 'O'}, {"Ö", 'O'}, {"ò", 'o'}, {"ó", 'o'}, {"ô", 'o'},
+ {"õ", 'o'}, {"ö", 'o'}, {"Ù", 'U'}, {"Ú", 'U'}, {"Û", 'U'}, {"Ü", 'U'},
+ {"ù", 'u'}, {"ú", 'u'}, {"û", 'u'}, {"ü", 'u'}, {"Ý", 'Y'}, {"ý", 'y'},
+ {"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'}, {"ñ", 'n'},
+ };
+
+ for (size_t i = 0; i < input_string.length();) {
+ int len = utf8_len(input_string[i]);
+ std::string curChar = input_string.substr(i, len);
+ auto iter = accent_map.find(curChar);
+ if (iter != accent_map.end()) {
+ resultString += iter->second;
+ } else {
+ resultString += curChar;
+ }
+ i += len;
+ }
+
+ return resultString;
+ }
+
+ static size_t utf8_len(char src) {
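+ // length of a UTF-8 sequence from the high 4 bits of the lead byte (continuation bytes map to 1)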
+ const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
+ uint8_t highbits = static_cast<uint8_t>(src) >> 4;
+ return lookup[highbits];
+ }
+
+ const llama_vocab & vocab;
+};
+
+typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
} FRAGMENT_BUFFER_VARIANT_TYPE;
-struct fragment_buffer_variant{
+struct fragment_buffer_variant {
fragment_buffer_variant(llama_vocab::id _token)
:
type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
// #define PRETOKENIZERDEBUG
-static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer)
-{
+static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
// for each special token
for (const auto & st: vocab.special_tokens_cache) {
const auto & special_token = st.first;
switch (vocab.type) {
case LLAMA_VOCAB_TYPE_SPM:
{
- for (const auto & fragment: fragment_buffer)
- {
- if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
- {
+ for (const auto & fragment: fragment_buffer) {
+ if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
// without adding this leading whitespace, we do not get the same results as the original tokenizer
// TODO: It's likely possible to get rid of this string copy entirely
llm_tokenizer_spm tokenizer(vocab);
llama_escape_whitespace(raw_text);
tokenizer.tokenize(raw_text, output);
- }
- else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
- {
+ } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token);
}
}
} break;
case LLAMA_VOCAB_TYPE_BPE:
{
- for (const auto & fragment: fragment_buffer)
- {
- if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
- {
+ for (const auto & fragment: fragment_buffer) {
+ if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
#ifdef PRETOKENIZERDEBUG
#endif
llm_tokenizer_bpe tokenizer(vocab);
tokenizer.tokenize(raw_text, output);
+ } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+ output.push_back(fragment.token);
}
- else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
- {
+ }
+ } break;
+ case LLAMA_VOCAB_TYPE_WPM:
+ {
+ for (const auto & fragment: fragment_buffer) {
+ if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+ auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+ LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+ llm_tokenizer_wpm tokenizer(vocab);
+ tokenizer.tokenize(raw_text, output);
+ } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token);
}
}
// graph inputs
{
ggml_init_params init_params = {
- /* .mem_size */ ggml_tensor_overhead()*5,
+ /* .mem_size */ ggml_tensor_overhead()*7,
/* .mem_buffer */ nullptr,
/* .no_alloc */ true,
};
ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
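+ // per-token pooling weights for embedding models (filled with 1/n_tokens in llama_set_inputs)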
+ ctx->inp_sum = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, 1, cparams.n_batch);
ggml_set_name(ctx->inp_tokens, "inp_tokens");
ggml_set_name(ctx->inp_embd, "inp_embd");
ggml_set_name(ctx->inp_pos, "inp_pos");
ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask");
ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
+ ggml_set_name(ctx->inp_sum, "inp_sum");
ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead());
ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES);
- ctx->alloc = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
// build worst-case graph
int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
int n_past = cparams.n_ctx - n_tokens;
llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
- ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0));
+ ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
// initialize scheduler with the worst-case graph
- ggml_backend_sched_init_measure(ctx->sched, gf);
- ctx->alloc = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
+ if (!ggml_backend_sched_reserve(ctx->sched, gf)) {
+ LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+ llama_free(ctx);
+ return nullptr;
+ }
- for (ggml_backend_t backend : ctx->backends) {
- ggml_backend_buffer_t buf = ggml_backend_sched_get_buffer(ctx->sched, backend);
+ for (size_t i = 0; i < ctx->backends.size(); i++) {
+ ggml_backend_t backend = ctx->backends[i];
+ ggml_backend_buffer_type_t buft = backend_buft[i];
+ size_t size = ggml_backend_sched_get_buffer_size(ctx->sched, backend);
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
- ggml_backend_buffer_name(buf),
- ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
+ ggml_backend_buft_name(buft),
+ size / 1024.0 / 1024.0);
}
// note: the number of splits during measure is higher than during inference due to the kv shift
int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length) {
if (0 <= token && token < llama_n_vocab(model)) {
switch (llama_vocab_get_type(model->vocab)) {
+ case LLAMA_VOCAB_TYPE_WPM:
case LLAMA_VOCAB_TYPE_SPM: {
// NOTE: we accept all unsupported token types,
// suppressing them like CONTROL tokens.
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
+ s += "MATMUL_INT8 = " + std::to_string(ggml_cpu_has_matmul_int8()) + " | ";
return s.c_str();
}