"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
- "model.embed_tokens", # llama-hf nemotron olmoe
+ "model.embed_tokens", # llama-hf nemotron olmoe olmo_1124
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert
"language_model.embedding.word_embeddings", # persimmon
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
- "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe
+ "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo_1124
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
- "model.norm", # llama-hf baichuan internlm2 olmoe
+ "model.norm", # llama-hf baichuan internlm2 olmoe olmo_1124
"norm", # llama-pth
"transformer.norm_f", # mpt dbrx
"ln_f", # refact bloom qwen gpt2
# Attention query
MODEL_TENSOR.ATTN_Q: (
- "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe
+ "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo_1124
"layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert
"transformer.h.{bid}.attn.q_proj", # gpt-j
# Attention key
MODEL_TENSOR.ATTN_K: (
- "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe
+ "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo_1124
"layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert
"transformer.h.{bid}.attn.k_proj", # gpt-j
# Attention value
MODEL_TENSOR.ATTN_V: (
- "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe
+ "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo_1124
"layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
- "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe
+ "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo_1124
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
),
MODEL_TENSOR.ATTN_POST_NORM: (
- "model.layers.{bid}.post_attention_layernorm", # gemma2
+ "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo_1124
),
# Rotary embeddings
# Post feed-forward norm
MODEL_TENSOR.FFN_POST_NORM: (
- "model.layers.{bid}.post_feedforward_layernorm", # gemma2
+ "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo_1124
),
MODEL_TENSOR.FFN_GATE_INP: (
"transformer.blocks.{bid}.ffn.up_proj", # mpt
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
"h.{bid}.mlp.dense_h_to_4h", # bloom
- "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron
+ "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo_1124
"layers.{bid}.feed_forward.w3", # llama-pth
"encoder.layer.{bid}.intermediate.dense", # bert
"transformer.h.{bid}.mlp.fc_in", # gpt-j
# Feed-forward gate
MODEL_TENSOR.FFN_GATE: (
- "model.layers.{bid}.mlp.gate_proj", # llama-hf refact
+ "model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo_1124
"layers.{bid}.feed_forward.w1", # llama-pth
"transformer.h.{bid}.mlp.w2", # qwen
"transformer.h.{bid}.mlp.c_fc2", # jais
"transformer.blocks.{bid}.ffn.down_proj", # mpt
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
"h.{bid}.mlp.dense_4h_to_h", # bloom
- "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron
+ "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo_1124
"layers.{bid}.feed_forward.w2", # llama-pth
"encoder.layer.{bid}.output.dense", # bert
"transformer.h.{bid}.mlp.fc_out", # gpt-j
MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
- "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon
+ "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo_1124
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
"transformer.layers.{bid}.attn.q_norm", # openelm
MODEL_TENSOR.ATTN_K_NORM: (
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
- "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon
+ "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo_1124
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
"transformer.layers.{bid}.attn.k_norm", # openelm
LLM_ARCH_COMMAND_R,
LLM_ARCH_DBRX,
LLM_ARCH_OLMO,
+ LLM_ARCH_OLMO_1124,
LLM_ARCH_OLMOE,
LLM_ARCH_OPENELM,
LLM_ARCH_ARCTIC,
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_DBRX, "dbrx" },
{ LLM_ARCH_OLMO, "olmo" },
+ { LLM_ARCH_OLMO_1124, "olmo_1124" },
{ LLM_ARCH_OLMOE, "olmoe" },
{ LLM_ARCH_OPENELM, "openelm" },
{ LLM_ARCH_ARCTIC, "arctic" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_OLMO_1124,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ },
+ },
{
LLM_ARCH_OLMOE,
{
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_OLMO_1124:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
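+ // olmo_1124 checkpoints: 16, 32 and 40 layers correspond to the 1B, 7B and 13B variants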
+ switch (hparams.n_layer) {
+ case 16: model.type = e_model::MODEL_1B; break;
+ case 32: model.type = e_model::MODEL_7B; break;
+ case 40: model.type = e_model::MODEL_13B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
case LLM_ARCH_OLMOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
+ case LLM_ARCH_OLMO_1124:
+ {
+ model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = model.layers[i];
+
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
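+ // per-layer RMS norm weights: Q/K norms applied before RoPE, post-attention norm applied after the attention block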
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
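+ // gated (SwiGLU-style) feed-forward weights plus the post-feed-forward norm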
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+ layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+ }
+ } break;
case LLM_ARCH_OLMOE:
{
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
return gf;
}
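+ // graph for olmo_1124: llama-style attention/FFN, but with RMS norm applied to the Q and K
+ // projections before RoPE, and post-attention / post-feed-forward norms (cf. gemma2) instead of pre-norms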
+ struct ggml_cgraph * build_olmo_1124() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+ // mutable variable, needed during the last layer of the computation to skip unused tokens
+ int32_t n_tokens = this->n_tokens;
+
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = build_inp_pos();
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+
+ cur = inpL;
+
+ // self_attention
+ {
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
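+ // RMS-normalize the full Q and K projections before reshaping to per-head layout and applying RoPE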
+ Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(Qcur, "Qcur_normed", il);
+
+ Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(Kcur, "Kcur_normed", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Qcur, "Qcur_rope", il);
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Kcur, "Kcur_rope", il);
+
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+ model.layers[il].wo, NULL,
+ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ }
+
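+ // post-norm: normalize the attention output before adding the residual (inpSA)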
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.layers[il].attn_post_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_post_norm", il);
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ n_tokens = n_outputs;
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward network
+ cur = llm_build_ffn(ctx0, lctx, ffn_inp,
+ model.layers[il].ffn_up, NULL, NULL,
+ model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+ cb(cur, "ffn_out", il);
+
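+ // post-norm: normalize the feed-forward output before adding its residual (ffn_inp)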
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.layers[il].ffn_post_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_post_norm", il);
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "ffn_out", il);
+
+ cur = lctx.cvec.apply_to(ctx0, cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
+
// based on the build_qwen2moe() function, changes:
// * removed shared experts
// * removed bias
{
result = llm.build_olmo();
} break;
+ case LLM_ARCH_OLMO_1124:
+ {
+ result = llm.build_olmo_1124();
+ } break;
case LLM_ARCH_OLMOE:
{
result = llm.build_olmoe();
case LLM_ARCH_QWEN:
case LLM_ARCH_QWEN2:
case LLM_ARCH_QWEN2MOE:
+ case LLM_ARCH_OLMO_1124:
case LLM_ARCH_OLMOE:
case LLM_ARCH_PHI2:
case LLM_ARCH_PHI3: