From: Georgi Gerganov Date: Sun, 12 Oct 2025 05:37:14 +0000 (+0300) Subject: talk-llama : sync llama.cpp X-Git-Tag: upstream/1.8.2~26 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=ff4c1a5a53887829d1eed250f554c021bfcd170b;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp talk-llama : sync llama.cpp --- diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 4e8d54c4..869e4dcc 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -93,12 +93,14 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, { LLM_ARCH_LFM2, "lfm2" }, + { LLM_ARCH_LFM2MOE, "lfm2moe" }, { LLM_ARCH_DREAM, "dream" }, { LLM_ARCH_SMALLTHINKER, "smallthinker" }, { LLM_ARCH_LLADA, "llada" }, { LLM_ARCH_LLADA_MOE, "llada-moe" }, { LLM_ARCH_SEED_OSS, "seed_oss" }, { LLM_ARCH_GROVEMOE, "grovemoe" }, + { LLM_ARCH_APERTUS, "apertus" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -217,6 +219,11 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, + // sentence-transformers dense modules feature dims + { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" }, + { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" }, + { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" }, + { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, @@ -256,6 +263,11 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" }, { LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" }, + { LLM_KV_XIELU_ALPHA_N, "xielu.alpha_n" }, + { LLM_KV_XIELU_ALPHA_P, "xielu.alpha_p" }, + { LLM_KV_XIELU_BETA, "xielu.beta" }, + { LLM_KV_XIELU_EPS, "xielu.eps" }, + // deprecated { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, @@ -1064,6 +1076,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_DENSE_2_OUT, "dense_2" }, + { LLM_TENSOR_DENSE_3_OUT, "dense_3" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, @@ -2098,6 +2112,32 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_OUTPUT, "output" }, } }, + { + LLM_ARCH_LFM2MOE, + { + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" }, + { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, + { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + } + }, { LLM_ARCH_SMALLTHINKER, { @@ -2119,6 +2159,25 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" } }, }, + { + LLM_ARCH_APERTUS, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_DREAM, { @@ -2229,6 +2288,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, @@ -2468,6 +2529,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_PLAMO2: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: return true; default: diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index b5c6f3d7..c3ae7165 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -97,12 +97,14 @@ enum llm_arch { LLM_ARCH_SMOLLM3, LLM_ARCH_OPENAI_MOE, LLM_ARCH_LFM2, + LLM_ARCH_LFM2MOE, LLM_ARCH_DREAM, LLM_ARCH_SMALLTHINKER, LLM_ARCH_LLADA, LLM_ARCH_LLADA_MOE, LLM_ARCH_SEED_OSS, LLM_ARCH_GROVEMOE, + LLM_ARCH_APERTUS, LLM_ARCH_UNKNOWN, }; @@ -260,10 +262,21 @@ enum llm_kv { LLM_KV_SHORTCONV_L_CACHE, + LLM_KV_XIELU_ALPHA_N, + LLM_KV_XIELU_ALPHA_P, + LLM_KV_XIELU_BETA, + LLM_KV_XIELU_EPS, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, LLM_KV_TOKENIZER_MIDDLE_ID, + + // sentence-transformers dense layers in and out features + LLM_KV_DENSE_2_FEAT_IN, + LLM_KV_DENSE_2_FEAT_OUT, + LLM_KV_DENSE_3_FEAT_IN, + LLM_KV_DENSE_3_FEAT_OUT, }; enum llm_tensor { @@ -271,6 +284,8 @@ enum llm_tensor { LLM_TENSOR_TOKEN_EMBD_NORM, LLM_TENSOR_TOKEN_TYPES, LLM_TENSOR_POS_EMBD, + LLM_TENSOR_DENSE_2_OUT, + LLM_TENSOR_DENSE_3_OUT, LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_ROPE_FREQS, diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index 66e6c6a3..956c4e08 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -590,7 +590,7 @@ int32_t llm_chat_apply_template( ss << message->content << "<|end_of_text|>\n"; } if (add_ass) { - ss << "<|start_of_role|>assistant<|end_of_role|>\n"; + ss << "<|start_of_role|>assistant<|end_of_role|>"; } } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { // GigaChat template diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index d8a8b5e6..e7526e7d 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -2346,6 +2346,12 @@ llama_context * llama_init_from_model( return nullptr; } + if (params.pooling_type != model->hparams.pooling_type) { + //user-specified pooling-type is different from the model default + LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__, + model->hparams.pooling_type, params.pooling_type); + } + try { auto * ctx = new llama_context(*model, params); return ctx; diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index 90cd885a..a24853c6 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -1853,6 +1853,23 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +void llm_graph_context::build_dense_out( + ggml_tensor * dense_2, + ggml_tensor * dense_3) const { + if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) { + return; + } + ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd; + GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd"); + + cur = ggml_mul_mat(ctx0, dense_2, cur); + cur = ggml_mul_mat(ctx0, dense_3, cur); + cb(cur, "result_embd_pooled", -1); + res->t_embd_pooled = cur; + ggml_build_forward_expand(gf, cur); +} + + void llm_graph_context::build_pooling( ggml_tensor * cls, ggml_tensor * cls_b, diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 34b984af..dc84b794 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -814,6 +814,14 @@ struct llm_graph_context { ggml_tensor * cls_b, ggml_tensor * cls_out, ggml_tensor * cls_out_b) const; + + // + // dense (out) + // + + void build_dense_out( + ggml_tensor * dense_2, + ggml_tensor * dense_3) const; }; // TODO: better name diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index c04ac58f..db65d69e 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -140,7 +140,11 @@ uint32_t llama_hparams::n_embd_s() const { } bool llama_hparams::is_recurrent(uint32_t il) const { - return recurrent_layer_arr[il]; + if (il < n_layer) { + return recurrent_layer_arr[il]; + } + + GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer); } uint32_t llama_hparams::n_pos_per_embd() const { diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 0fe4b569..4e7f73ec 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -42,7 +42,7 @@ struct llama_hparams { uint32_t n_embd; uint32_t n_embd_features = 0; uint32_t n_layer; - int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache + int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache uint32_t n_rot; uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head @@ -169,6 +169,18 @@ struct llama_hparams { uint32_t laurel_rank = 64; uint32_t n_embd_altup = 256; + // needed for sentence-transformers dense layers + uint32_t dense_2_feat_in = 0; // in_features of the 2_Dense + uint32_t dense_2_feat_out = 0; // out_features of the 2_Dense + uint32_t dense_3_feat_in = 0; // in_features of the 3_Dense + uint32_t dense_3_feat_out = 0; // out_features of the 3_Dense + + // xIELU + std::array xielu_alpha_n; + std::array xielu_alpha_p; + std::array xielu_beta; + std::array xielu_eps; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/examples/talk-llama/llama-kv-cache-iswa.cpp b/examples/talk-llama/llama-kv-cache-iswa.cpp index 827302e6..facba1d0 100644 --- a/examples/talk-llama/llama-kv-cache-iswa.cpp +++ b/examples/talk-llama/llama-kv-cache-iswa.cpp @@ -220,7 +220,7 @@ bool llama_kv_cache_iswa::get_can_shift() const { } void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { - if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { kv_base->state_write(io, seq_id, flags); } @@ -228,7 +228,7 @@ void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id } void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { - if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { kv_base->state_read(io, seq_id, flags); } diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 816f2d5d..736693e1 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -123,11 +123,8 @@ llama_kv_cache::llama_kv_cache( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k; - ggml_tensor * v; - - k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); - v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); + ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); + ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); ggml_format_name(k, "cache_k_l%d", il); ggml_format_name(v, "cache_v_l%d", il); diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index abf65248..dfb8439e 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -73,7 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch, false); + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); } if (ubatch.n_tokens == 0) { @@ -175,17 +177,17 @@ std::map llama_memory_hybrid::memory_breakdo } void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { - GGML_UNUSED(flags); - - mem_attn->state_write(io, seq_id); - mem_recr->state_write(io, seq_id); + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { + mem_attn->state_write(io, seq_id, flags); + } + mem_recr->state_write(io, seq_id, flags); } void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { - GGML_UNUSED(flags); - - mem_attn->state_read(io, seq_id); - mem_recr->state_read(io, seq_id); + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { + mem_attn->state_read(io, seq_id, flags); + } + mem_recr->state_read(io, seq_id, flags); } llama_kv_cache * llama_memory_hybrid::get_mem_attn() const { diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index 44645fcd..d67f5a5f 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -136,6 +136,7 @@ void llama_memory_recurrent::clear(bool data) { } bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1); uint32_t new_head = size; if (p0 < 0) { @@ -156,7 +157,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (tail_id >= 0) { const auto & cell = cells[tail_id]; // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) { + //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n"); return false; } // invalidate tails which will be cleared @@ -167,6 +169,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } else { // seq_id is negative, then the range should include everything or nothing if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: `seq_id` is negative, so returning false\n"); return false; } } @@ -379,7 +382,9 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch, false); + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); } if (ubatch.n_tokens == 0) { @@ -856,9 +861,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { if (dest_seq_id != -1) { // single sequence - seq_rm(dest_seq_id, -1, -1); + if (cell_count == 0) { + return true; + } + llama_batch_allocr balloc(hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1); diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 8182a9ad..aa3a65f8 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -465,6 +465,8 @@ namespace GGUFMeta { // TODO: this is not very clever - figure out something better template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + llama_model_loader::llama_model_loader( const std::string & fname, diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index ffd9286e..36d495d6 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -114,6 +114,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; case LLM_TYPE_A13B: return "A13B"; + case LLM_TYPE_8B_A1B: return "8B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; @@ -310,7 +311,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara } // CPU: ACCEL -> GPU host -> CPU extra -> CPU -static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts) { +static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; // add ACCEL buffer types @@ -331,11 +332,13 @@ static buft_list_t make_cpu_buft_list(const std::vector & de // generally, this will be done using the first device in the list // a better approach would be to handle this on a weight-by-weight basis using the offload_op // function of the device to determine if it would benefit from being stored in a host buffer - for (auto * dev : devices) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); - if (buft) { - buft_list.emplace_back(dev, buft); - break; + if (!no_host) { + for (auto * dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + break; + } } } @@ -512,9 +515,13 @@ void llama_model::load_hparams(llama_model_loader & ml) { llm_arch_is_recurrent(ml.get_arch())); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); - std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); + std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0.0f); + std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); + std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); + std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); @@ -1084,7 +1091,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; default: type = LLM_TYPE_UNKNOWN; - } + } + + // Load attention parameters + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); } break; case LLM_ARCH_GPT2: { @@ -1207,12 +1218,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.set_swa_pattern(6); hparams.causal_attn = false; // embeddings do not use causal attention - hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_scale_train_swa = 1.0f; - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + + //applied only if model converted with --sentence-transformers-dense-modules + ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); + ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false); + + GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd"); + GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd"); switch (hparams.n_layer) { case 24: type = LLM_TYPE_0_3B; break; @@ -1985,13 +2005,28 @@ void llama_model::load_hparams(llama_model_loader & ml) { for (uint32_t il = 0; il < hparams.n_layer; ++il) { hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; } + hparams.n_layer_dense_lead = hparams.n_layer; switch (hparams.n_ff()) { case 4608: type = LLM_TYPE_350M; break; case 6912: type = LLM_TYPE_700M; break; case 8192: type = LLM_TYPE_1_2B; break; case 10752: type = LLM_TYPE_2_6B; break; - default: type = LLM_TYPE_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_LFM2MOE: + { + ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; } + + type = LLM_TYPE_8B_A1B; } break; case LLM_ARCH_SMALLTHINKER: { @@ -2029,6 +2064,19 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_APERTUS: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -2062,7 +2110,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false"); // build a list of buffer types for the CPU and GPU devices - pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts); + pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); for (auto * dev : devices) { buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); // add CPU buffer types as a fallback @@ -3392,17 +3440,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_PLAMO2: { + // mamba parameters const uint32_t d_conv = hparams.ssm_d_conv; const uint32_t d_state = hparams.ssm_d_state; const uint32_t num_heads = hparams.ssm_dt_rank; const uint32_t intermediate_size = hparams.ssm_d_inner; - const uint32_t head_dim = intermediate_size / num_heads; - const uint32_t qk_dim = head_dim; - const uint32_t v_dim = head_dim; - const int64_t num_attention_heads = hparams.n_head(); - const int64_t q_num_heads = num_attention_heads; const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + // attention parameters + const uint32_t qk_dim = hparams.n_embd_head_k; + const uint32_t v_dim = hparams.n_embd_head_v; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output @@ -3436,6 +3484,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); } else { + const int64_t num_attention_heads = hparams.n_head(i); + const int64_t q_num_heads = num_attention_heads; const int64_t num_key_value_heads = hparams.n_head_kv(i); const int64_t k_num_heads = num_key_value_heads; const int64_t v_num_heads = num_key_value_heads; @@ -3444,8 +3494,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t v_proj_dim = v_num_heads * v_dim; layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); } @@ -3645,6 +3695,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } + // Dense linear weights + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); + dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4825,11 +4880,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); } } } @@ -5787,6 +5844,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); @@ -5798,11 +5856,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - // ffn is same for transformer and conv layers + + const bool is_moe_layer = i >= static_cast(hparams.n_layer_dense_lead); + + // ffn/moe is same for transformer and conv layers layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + if (is_moe_layer) { + GGML_ASSERT(n_expert && n_expert_used); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } else { // dense + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } // for operator_norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -5907,6 +5977,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_chexps = create_tensor(tn(LLM_TENSOR_FFN_UP_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); } } break; + case LLM_ARCH_APERTUS: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + + // Q and K layernorms for Apertus + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6241,7 +6353,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); } - if (arch == LLM_ARCH_SMALLTHINKER) { + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } @@ -7776,6 +7888,8 @@ struct llm_build_bert : public llm_graph_context { } if (model.layers[il].attn_q_norm) { + Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, @@ -7785,6 +7899,8 @@ struct llm_build_bert : public llm_graph_context { } if (model.layers[il].attn_k_norm) { + Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, @@ -8167,6 +8283,9 @@ struct llm_build_mpt : public llm_graph_context { // Q/K Layernorm if (model.layers[il].attn_q_norm) { + Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens); + Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, @@ -11751,6 +11870,7 @@ struct llm_graph_context_mamba : public llm_graph_context { // TODO: skip computing output earlier for unused tokens y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + cb(y, "mamba2_y_add_d", il); y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // grouped RMS norm @@ -14705,6 +14825,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba { ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); + ggml_build_forward_expand(gf, inpL); auto * inp = build_inp_mem_hybrid(); @@ -14736,7 +14857,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba { // add residual cur = ggml_add(ctx0, cur, inpSA); - cb(cur, "block_out", il); + cb(cur, "nemotron_h_block_out", il); // input for next layer inpL = cur; @@ -16192,10 +16313,10 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { } ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - ggml_tensor * inpSA, - const llama_model & model, - const int il) { + ggml_tensor * cur, + ggml_tensor * inpSA, + const llama_model & model, + const int il) { // For Granite architectures - scale residual if (hparams.f_residual_scale) { @@ -17607,6 +17728,7 @@ private: const int64_t n_embd_head_q = hparams.n_embd_head_k; const int64_t n_embd_head_k = hparams.n_embd_head_k; const int64_t n_embd_head_v = hparams.n_embd_head_v; + int32_t n_head = hparams.n_head(il); int32_t n_head_kv = hparams.n_head_kv(il); const int64_t q_offset = 0; @@ -18523,6 +18645,8 @@ struct llm_build_lfm2 : public llm_graph_context { ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + const bool is_moe_layer = il >= static_cast(hparams.n_layer_dense_lead); + auto * prev_cur = cur; cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "model.layers.{}.operator_norm", il); @@ -18537,7 +18661,16 @@ struct llm_build_lfm2 : public llm_graph_context { } cur = ggml_add(ctx0, prev_cur, cur); - cur = ggml_add(ctx0, cur, build_feed_forward(cur, il)); + + auto * ffn_norm_out = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(ffn_norm_out, "model.layers.{}.ffn_norm", il); + + ggml_tensor * ffn_out = is_moe_layer ? + build_moe_feed_forward(ffn_norm_out, il) : + build_dense_feed_forward(ffn_norm_out, il); + cb(ffn_norm_out, "model.layers.{}.ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_out); } cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1); @@ -18552,23 +18685,32 @@ struct llm_build_lfm2 : public llm_graph_context { ggml_build_forward_expand(gf, cur); } - ggml_tensor * build_feed_forward(ggml_tensor * cur, - int il) const { - cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "model.layers.{}.ffn_norm", il); + ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, + int il) const { + return build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + static_cast(hparams.expert_gating_func), + il); + } + ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, + int il) const { GGML_ASSERT(!model.layers[il].ffn_up_b); GGML_ASSERT(!model.layers[il].ffn_gate_b); GGML_ASSERT(!model.layers[il].ffn_down_b); - cur = build_ffn(cur, + return build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "model.layers.{}.feed_forward.w2", il); - - return cur; } ggml_tensor * build_attn_block(ggml_tensor * cur, @@ -19088,6 +19230,141 @@ struct llm_build_grovemoe : public llm_graph_context { } }; +struct llm_build_apertus : public llm_graph_context { + llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_pos", il); + cb(Kcur, "Kcur_pos", il); + cb(Vcur, "Vcur_pos", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network with xIELU activation + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // Up projection + ggml_tensor * up = build_lora_mm(model.layers[il].ffn_up, cur); + cb(up, "ffn_up", il); + + float alpha_n_val = hparams.xielu_alpha_n[il]; + float alpha_p_val = hparams.xielu_alpha_p[il]; + float beta_val = hparams.xielu_beta[il]; + float eps_val = hparams.xielu_eps[il]; + + // Apply xIELU activation + ggml_tensor * activated = ggml_xielu(ctx0, up, alpha_n_val, alpha_p_val, beta_val, eps_val); + cb(activated, "ffn_xielu", il); + + // Down projection + cur = build_lora_mm(model.layers[il].ffn_down, activated); + cb(cur, "ffn_down", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, nullptr, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; @@ -19603,6 +19880,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: { llm = std::make_unique(*this, params); } break; @@ -19618,6 +19896,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_APERTUS: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -19625,6 +19907,12 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + // if the gguf model was converted with --sentence-transformers-dense-modules + // there will be two additional dense projection layers + // dense linear projections are applied after pooling + // TODO: move reranking logic here and generalize + llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); + return llm->res->get_gf(); } @@ -19649,6 +19937,7 @@ llama_model_params llama_model_default_params() { /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.use_extra_bufts =*/ true, + /*.no_host =*/ false, }; return result; @@ -19820,10 +20109,12 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: case LLM_ARCH_SMALLTHINKER: case LLM_ARCH_GLM4_MOE: case LLM_ARCH_SEED_OSS: case LLM_ARCH_GROVEMOE: + case LLM_ARCH_APERTUS: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: @@ -19934,6 +20225,10 @@ bool llama_model_is_recurrent(const llama_model * model) { return llm_arch_is_recurrent(model->arch); } +bool llama_model_is_hybrid(const llama_model * model) { + return llm_arch_is_hybrid(model->arch); +} + bool llama_model_is_diffusion(const llama_model * model) { return llm_arch_is_diffusion(model->arch); } diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index d73ce969..7f48662f 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -107,6 +107,7 @@ enum llm_type { LLM_TYPE_17B_16E, // llama4 Scout LLM_TYPE_17B_128E, // llama4 Maverick LLM_TYPE_A13B, + LLM_TYPE_8B_A1B, // lfm2moe LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, LLM_TYPE_106B_A12B, // GLM-4.5-Air @@ -380,6 +381,12 @@ struct llama_layer { // openai-moe struct ggml_tensor * attn_sinks = nullptr; + // xIELU activation parameters for Apertus + struct ggml_tensor * ffn_act_alpha_n = nullptr; + struct ggml_tensor * ffn_act_alpha_p = nullptr; + struct ggml_tensor * ffn_act_beta = nullptr; + struct ggml_tensor * ffn_act_eps = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; @@ -431,6 +438,12 @@ struct llama_model { std::vector layers; + //Dense linear projections for SentenceTransformers models like embeddinggemma + // For Sentence Transformers models structure see + // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models + struct ggml_tensor * dense_2_out_layers = nullptr; + struct ggml_tensor * dense_3_out_layers = nullptr; + llama_model_params params; // gguf metadata diff --git a/examples/talk-llama/llama-sampling.cpp b/examples/talk-llama/llama-sampling.cpp index 2186f827..55d2e355 100644 --- a/examples/talk-llama/llama-sampling.cpp +++ b/examples/talk-llama/llama-sampling.cpp @@ -2541,8 +2541,13 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_ if (n_non_eog == 0) { cur_p->size = 1; cur_p->data[0].id = ctx->vocab->token_eot(); + if (cur_p->data[0].id == LLAMA_TOKEN_NULL) { + cur_p->data[0].id = ctx->vocab->token_eos(); + } cur_p->data[0].logit = 1.0f; + GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL); + return; } diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index da938af0..7fffd171 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -347,6 +347,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_OLMO: case LLAMA_VOCAB_PRE_TYPE_JAIS: case LLAMA_VOCAB_PRE_TYPE_TRILLION: + case LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING: regex_exprs = { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; @@ -1961,6 +1962,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "trillion") { pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION; clean_spaces = false; + } else if ( + tokenizer_pre == "granite-docling") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING; + clean_spaces = false; } else if ( tokenizer_pre == "bailingmoe" || tokenizer_pre == "llada-moe") { @@ -2166,6 +2171,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|end|>" || t.first == "" || t.first == "<|endoftext|>" + || t.first == "<|end_of_text|>" // granite || t.first == "" || t.first == "_" || t.first == "<|end▁of▁sentence|>" // DeepSeek diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 0d2f28c3..5e468675 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -8,46 +8,47 @@ // pre-tokenization types enum llama_vocab_pre_type { - LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, - LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, - LLAMA_VOCAB_PRE_TYPE_FALCON = 4, - LLAMA_VOCAB_PRE_TYPE_MPT = 5, - LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, - LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, - LLAMA_VOCAB_PRE_TYPE_REFACT = 8, - LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, - LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, - LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, - LLAMA_VOCAB_PRE_TYPE_OLMO = 12, - LLAMA_VOCAB_PRE_TYPE_DBRX = 13, - LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, - LLAMA_VOCAB_PRE_TYPE_PORO = 15, - LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, - LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, - LLAMA_VOCAB_PRE_TYPE_VIKING = 18, - LLAMA_VOCAB_PRE_TYPE_JAIS = 19, - LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, - LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, - LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, - LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, - LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, - LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, - LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, - LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, - LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, - LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, - LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, - LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, - LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, - LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, - LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, - LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, - LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + LLAMA_VOCAB_PRE_TYPE_MPT = 5, + LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + LLAMA_VOCAB_PRE_TYPE_REFACT = 8, + LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, + LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, + LLAMA_VOCAB_PRE_TYPE_OLMO = 12, + LLAMA_VOCAB_PRE_TYPE_DBRX = 13, + LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, + LLAMA_VOCAB_PRE_TYPE_PORO = 15, + LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, + LLAMA_VOCAB_PRE_TYPE_VIKING = 18, + LLAMA_VOCAB_PRE_TYPE_JAIS = 19, + LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, + LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, + LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, + LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, + LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, + LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, + LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, + LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, + LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, + LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, + LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, }; struct LLM_KV; diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 452d9ec5..a0a660bf 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -296,6 +296,7 @@ extern "C" { bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool use_extra_bufts; // use extra buffer types (used for weight repacking) + bool no_host; // bypass host buffer allowing extra buffers to be used }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations @@ -543,6 +544,9 @@ extern "C" { // Returns true if the model is recurrent (like Mamba, RWKV, etc.) LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + // Returns true if the model is hybrid (like Jamba, Granite, etc.) + LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); + // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.) LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model); @@ -791,8 +795,12 @@ extern "C" { size_t n_token_capacity, size_t * n_token_count_out); +// for backwards-compat #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 +// work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba) +#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 + typedef uint32_t llama_state_seq_flags; LLAMA_API size_t llama_state_seq_get_size_ext(