LLM_ARCH_RWKV6,
LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE,
+ LLM_ARCH_CHAMELEON,
LLM_ARCH_UNKNOWN,
};
{ LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
+ { LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
LLM_KV_DECODER_START_TOKEN_ID,
LLM_KV_ATTN_LOGIT_SOFTCAPPING,
LLM_KV_FINAL_LOGIT_SOFTCAPPING,
+ LLM_KV_SWIN_NORM,
LLM_KV_RESCALE_EVERY_N_LAYERS,
LLM_KV_TIME_MIX_EXTRA_DIM,
LLM_KV_TIME_DECAY_EXTRA_DIM,
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
{ LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
{ LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
+ { LLM_KV_SWIN_NORM, "%s.swin_norm" },
{ LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" },
{ LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" },
{ LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
+ {
+ LLM_ARCH_CHAMELEON,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+ },
+ },
{
LLM_ARCH_UNKNOWN,
{
bool vocab_only;
bool rope_finetuned;
bool use_par_res;
+ bool swin_norm;
uint32_t n_vocab;
uint32_t n_ctx_train; // context size the model was trained on
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_CHAMELEON:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default
+ ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm);
+
+ switch (hparams.n_layer) {
+ case 32: model.type = e_model::MODEL_7B; break;
+ case 48: model.type = e_model::MODEL_34B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
default: (void)0;
}
} else if (
tokenizer_pre == "exaone") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_EXAONE;
+ } else if (
+ tokenizer_pre == "chameleon") {
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHAMELEON;
+ vocab.tokenizer_add_bos = true;
+ vocab.tokenizer_clean_spaces = false;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}
}
} break;
+ case LLM_ARCH_CHAMELEON:
+ {
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+ // output
+ {
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+ // if output is NULL, init from the input tok embed
+ if (model.output == NULL) {
+ model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+ }
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+ layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head});
+ layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv});
+ layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+ layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+
+ layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
+ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+ }
+ } break;
default:
throw std::runtime_error("unknown architecture");
}
return gf;
}
+
+ // ref: https://github.com/facebookresearch/chameleon
+ // based on the original build_llama() function, changes:
+ // * qk-norm
+ // * swin-norm
+ // * removed bias
+ // * removed MoE
+ struct ggml_cgraph * build_chameleon() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+ // mutable variable, needed during the last layer of the computation to skip unused tokens
+ int32_t n_tokens = this->n_tokens;
+
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = build_inp_pos();
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+
+ // norm
+ if (hparams.swin_norm) {
+ cur = inpL;
+ } else {
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+ }
+
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
+ if (model.layers[il].attn_q_norm) {
+ Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
+ ggml_element_size(Qcur) * n_embd_head,
+ ggml_element_size(Qcur) * n_embd_head * n_head,
+ 0);
+ cb(Qcur, "Qcur", il);
+
+ Qcur = llm_build_norm(ctx0, Qcur, hparams,
+ model.layers[il].attn_q_norm,
+ model.layers[il].attn_q_norm_b,
+ LLM_NORM, cb, il);
+ cb(Qcur, "Qcur", il);
+ }
+
+ if (model.layers[il].attn_k_norm) {
+ Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens,
+ ggml_element_size(Kcur) * n_embd_head,
+ ggml_element_size(Kcur) * n_embd_head * n_head_kv,
+ 0);
+ cb(Kcur, "Kcur", il);
+
+ Kcur = llm_build_norm(ctx0, Kcur, hparams,
+ model.layers[il].attn_k_norm,
+ model.layers[il].attn_k_norm_b,
+ LLM_NORM, cb, il);
+ cb(Kcur, "Kcur", il);
+ }
+
+ Qcur = ggml_rope_ext(
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_rope_ext(
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Kcur, "Kcur", il);
+
+ cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+ model.layers[il].wo, nullptr,
+ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+
+ if (hparams.swin_norm) {
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ }
+ }
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ n_tokens = n_outputs;
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward network
+ if (!hparams.swin_norm) {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+ }
+
+ cur = llm_build_ffn(ctx0, lctx, cur,
+ model.layers[il].ffn_up, NULL, NULL,
+ model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+ cb(cur, "ffn_out", il);
+
+ if (hparams.swin_norm) {
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+ }
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "ffn_out", il);
+
+ cur = lctx.cvec.apply_to(ctx0, cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+ cb(cur, "result_output_with_img_logits", -1);
+
+ // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs.
+ // Needs to be removed once image outputs are supported.
+ int img_token_end_idx = 8196;
+ int img_token_start_idx = 4;
+ int num_img_tokens = img_token_end_idx - img_token_start_idx;
+ // creates 1d tensor of size num_img_tokens and values -FLT_MAX,
+ // which ensures that text token values are always at least larger than image token values
+ struct ggml_tensor * img_logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, num_img_tokens);
+ img_logits = ggml_clamp(ctx0, img_logits, -FLT_MAX, -FLT_MAX);
+ cb(img_logits, "img_logits", -1);
+ cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
};
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
{
result = llm.build_rwkv6();
} break;
+ case LLM_ARCH_CHAMELEON:
+ {
+ result = llm.build_chameleon();
+ } break;
default:
GGML_ABORT("fatal error");
}
case LLM_ARCH_CHATGLM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
+ case LLM_ARCH_CHAMELEON:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2