--- /dev/null
+set( CMAKE_SYSTEM_NAME Darwin )
+set( CMAKE_SYSTEM_PROCESSOR arm64 )
+
+set( target arm64-apple-darwin-macho )
+
+set( CMAKE_C_COMPILER clang )
+set( CMAKE_CXX_COMPILER clang++ )
+
+set( CMAKE_C_COMPILER_TARGET ${target} )
+set( CMAKE_CXX_COMPILER_TARGET ${target} )
+
+set( arch_c_flags "-march=armv8.4-a -fvectorize -ffp-model=fast -fno-finite-math-only" )
+set( warn_c_flags "-Wno-format -Wno-unused-variable -Wno-unused-function" )
+
+set( CMAKE_C_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" )
+set( CMAKE_CXX_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" )
--- /dev/null
+set( CMAKE_SYSTEM_NAME Windows )
+set( CMAKE_SYSTEM_PROCESSOR arm64 )
+
+set( target arm64-pc-windows-msvc )
+
+set( CMAKE_C_COMPILER clang )
+set( CMAKE_CXX_COMPILER clang++ )
+
+set( CMAKE_C_COMPILER_TARGET ${target} )
+set( CMAKE_CXX_COMPILER_TARGET ${target} )
+
+set( arch_c_flags "-march=armv8.7-a -fvectorize -ffp-model=fast -fno-finite-math-only" )
+set( warn_c_flags "-Wno-format -Wno-unused-variable -Wno-unused-function -Wno-gnu-zero-variadic-macro-arguments" )
+
+set( CMAKE_C_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" )
+set( CMAKE_CXX_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" )
--- /dev/null
+set(CMAKE_SYSTEM_NAME Linux)
+set(CMAKE_SYSTEM_PROCESSOR riscv64)
+set(CMAKE_SYSTEM_VERSION 1)
+
+if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(riscv)")
+ message(STATUS "HOST SYSTEM ${CMAKE_HOST_SYSTEM_PROCESSOR}")
+else()
+ set(GNU_MACHINE riscv64-unknown-linux-gnu CACHE STRING "GNU compiler triple")
+ if (DEFINED ENV{RISCV_ROOT_PATH})
+ file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH)
+ else()
+ message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined")
+ endif()
+
+ set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain")
+ set(CMAKE_C_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc)
+ set(CMAKE_CXX_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++)
+ set(CMAKE_STRIP ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-strip)
+ set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu")
+ set(CMAKE_SYSROOT "${RISCV_ROOT_PATH}/sysroot")
+endif()
+
+set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
+set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
+set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
+set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
+set(CMAKE_C_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CMAKE_C_FLAGS}")
+set(CMAKE_CXX_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CXX_FLAGS}")
+set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -latomic")
--- /dev/null
+set( CMAKE_SYSTEM_NAME Windows )
+set( CMAKE_SYSTEM_PROCESSOR x86_64 )
+
+set( CMAKE_C_COMPILER clang )
+set( CMAKE_CXX_COMPILER clang++ )
{ LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
+ { LLM_ARCH_AFMOE, "afmoe" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_AFMOE,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+ { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+ { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+ { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+ { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+ { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+ },
+ },
{
LLM_ARCH_LLAMA4,
{
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
LLM_ARCH_BAILINGMOE2,
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
+ LLM_ARCH_AFMOE,
LLM_ARCH_ERNIE4_5,
LLM_ARCH_ERNIE4_5_MOE,
LLM_ARCH_HUNYUAN_MOE,
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_ROT_EMBD,
LLM_TENSOR_ATTN_SINKS,
+ LLM_TENSOR_ATTN_GATE,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_NORM,
int il) const {
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
+ // expand k later to enable rope fusion which directly writes into k-v cache
ggml_build_forward_expand(gf, q_cur);
- ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
+ ggml_build_forward_expand(gf, k_cur);
const auto * mctx_cur = inp->mctx;
p1 = std::numeric_limits<llama_pos>::max();
}
- // models like Mamba or RWKV can't have a state partially erased
+ // models like Mamba or RWKV can't have a state partially erased at the end
+ // of the sequence because their state isn't preserved for previous tokens
if (seq_id >= (int64_t) size) {
// could be fatal
return false;
int32_t & tail_id = cells[seq_id].tail;
if (tail_id >= 0) {
const auto & cell = cells[tail_id];
- // partial intersection is invalid
- if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+ // partial intersection is invalid if it includes the final pos
+ if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
return false;
}
case LLM_TYPE_15B: return "15B";
case LLM_TYPE_16B: return "16B";
case LLM_TYPE_20B: return "20B";
+ case LLM_TYPE_26B: return "26B";
case LLM_TYPE_27B: return "27B";
case LLM_TYPE_30B: return "30B";
case LLM_TYPE_32B: return "32B";
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_AFMOE:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+ ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
+ ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+
+ // Set up interleaved sliding window attention (ISWA)
+ // Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4)
+ if (hparams.n_swa > 0) {
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+ hparams.set_swa_pattern(4);
+ } else {
+ hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+ }
+
+ // Default to sigmoid if not set
+ if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
+ hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
+ }
+
+ switch (hparams.n_layer) {
+ case 56: type = LLM_TYPE_6B; break;
+ case 32: type = LLM_TYPE_26B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
case LLM_ARCH_DECI:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
+ case LLM_ARCH_AFMOE:
+ {
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+ // if output is NULL, init from the input tok embed
+ if (output == NULL) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
+
+ const int64_t n_ff_exp = hparams.n_ff_exp;
+ const int64_t n_expert_shared = hparams.n_expert_shared;
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ // dual attention normalization
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+ // attention projections
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+ // Q/K normalization
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+
+ // attention gating
+ layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+
+ // dual ffn normalization
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+ layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+
+ if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) {
+ // MoE layers
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+ layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
+
+ // grouped expert weights
+ layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+
+ // shared expert
+ if (n_expert_shared > 0) {
+ const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
+ layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
+ layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
+ layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
+ }
+ } else {
+ // Dense layers
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ }
+ }
+ } break;
case LLM_ARCH_ERNIE4_5:
case LLM_ARCH_ERNIE4_5_MOE:
{
{
llm = std::make_unique<llm_build_arcee>(*this, params);
} break;
+ case LLM_ARCH_AFMOE:
+ {
+ llm = std::make_unique<llm_build_afmoe>(*this, params);
+ } break;
case LLM_ARCH_ERNIE4_5:
{
llm = std::make_unique<llm_build_ernie4_5>(*this, params);
case LLM_ARCH_MINIMAX_M2:
case LLM_ARCH_COGVLM:
case LLM_ARCH_PANGU_EMBED:
+ case LLM_ARCH_AFMOE:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL:
LLM_TYPE_15B,
LLM_TYPE_16B,
LLM_TYPE_20B,
+ LLM_TYPE_26B,
LLM_TYPE_27B,
LLM_TYPE_30B,
LLM_TYPE_32B,
struct ggml_tensor * wk_enc = nullptr;
struct ggml_tensor * wv_enc = nullptr;
struct ggml_tensor * wo_enc = nullptr;
+ struct ggml_tensor * wqkv_gate = nullptr;
// attention bias
struct ggml_tensor * bq = nullptr;
#include "llama-vocab.h"
#include "llama-grammar.h"
+#include <array>
#include <algorithm>
#include <cassert>
#include <cfloat>
auto * ctx = new llama_sampler_grammar;
if (grammar_str != nullptr && grammar_str[0] != '\0') {
+ std::string trigger_pattern;
+ llama_grammar * grammar = nullptr;
// TODO: remove trigger_words support.
if (trigger_words != nullptr && num_trigger_words > 0) {
GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
- std::string trigger_pattern("[\\s\\S]*?(");
+ trigger_pattern = "[\\s\\S]*?(";
for (size_t i = 0; i < num_trigger_words; ++i) {
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
if (i > 0) {
trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
}
trigger_pattern += ")[\\s\\S]*";
- const auto * trigger_pattern_c = trigger_pattern.c_str();
- trigger_patterns = &trigger_pattern_c;
- num_trigger_patterns = 1;
+
+ std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
+ grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
+ } else {
+ grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
}
*ctx = {
/* .vocab = */ vocab,
/* .grammar_str = */ grammar_str,
/* .grammar_root = */ grammar_root,
- /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
+ /* .grammar = */ grammar,
};
if (!ctx->grammar) {
delete ctx;
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
+ case LLAMA_VOCAB_PRE_TYPE_AFMOE:
+ regex_exprs = {
+ // Digit handling - uses custom implementation in unicode.cpp
+ // Groups digits with leading 1-2 based on total length modulo 3
+ "\\p{AFMoE_digits}",
+ // CJK and Asian scripts (using direct Unicode literals)
+ "[一-鿿㐀-䶿豈--ゟ゠-ヿ・-゚⼀-เ--ក-က-႟ꩠ-ꩿꧠ-가-ᄀ-ᇿ]+",
+ // Main BPE pattern
+ "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+ };
+ break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
}
private:
uint32_t get_node(size_t index) {
- if (index > xcda_array_size) {
+ if (index >= xcda_array_size) {
throw std::runtime_error("Index out of array bounds in XCDA array!");
}
return xcda_array[index];
tokenizer_pre == "grok-2") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
clean_spaces = false;
+ } else if (
+ tokenizer_pre == "afmoe") {
+ pre_type = LLAMA_VOCAB_PRE_TYPE_AFMOE;
+ clean_spaces = false;
} else if (
tokenizer_pre == "minimax-m2") {
pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2;
LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39,
LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40,
LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41,
+ LLAMA_VOCAB_PRE_TYPE_AFMOE = 42,
};
struct LLM_KV;
--- /dev/null
+#include "models.h"
+
+llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ // MuP scaling: embeddings * sqrt(hidden_size)
+ // mup_enabled = true, hidden_size = 1024, scale = 32.0
+ inpL = ggml_scale(ctx0, inpL, sqrtf(float(n_embd)));
+ cb(inpL, "inp_embd_scaled", -1);
+
+ // inp_pos - contains the positions
+ ggml_tensor * inp_pos = build_inp_pos();
+ auto * inp_attn = build_attn_inp_kv_iswa();
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+ const float kq_scale = 1.0f/sqrtf(float(n_embd_head));
+
+ for (int il = 0; il < n_layer; ++il) {
+ ggml_tensor * inpSA = inpL;
+
+ // dual attention normalization (pre)
+ cur = build_norm(inpL,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ ggml_tensor * attn_inp = cur; // save input for gate computation
+
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
+ // compute gate from input
+ ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp);
+ cb(gate, "attn_gate_proj", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+ // Q/K normalization
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+ cb(Qcur, "Qcur_normed", il);
+ cb(Kcur, "Kcur_normed", il);
+
+ // RoPE only for sliding_attention layers
+ const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
+ ((il + 1) % hparams.n_no_rope_layer_step) != 0;
+ if (use_rope) {
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+ cb(Qcur, "Qcur_rope", il);
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+ cb(Kcur, "Kcur_rope", il);
+ }
+
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ cur = build_attn(inp_attn,
+ NULL, NULL, // wo will be applied after gating
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+ cb(cur, "attn_out", il);
+
+ // attention gating: attn_out * sigmoid(gate) BEFORE o_proj
+ gate = ggml_sigmoid(ctx0, gate);
+ cb(gate, "attn_gate_sig", il);
+ cur = ggml_mul(ctx0, cur, gate);
+ cb(cur, "attn_gated", il);
+
+ // now apply output projection
+ cur = build_lora_mm(model.layers[il].wo, cur);
+ cb(cur, "attn_o_proj", il);
+ }
+
+ // dual attention normalization (post)
+ cur = build_norm(cur,
+ model.layers[il].attn_post_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_post_norm", il);
+
+ if (il == n_layer - 1 && inp_out_ids) {
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // dual ffn normalization (pre)
+ cur = build_norm(ffn_inp,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "ffn_norm", il);
+
+ // MoE or dense FFN
+ if ((uint32_t)il >= hparams.n_layer_dense_lead) {
+ // MoE layer with sigmoid routing, normalization, and scaling
+ ggml_tensor * moe_out = build_moe_ffn(cur,
+ model.layers[il].ffn_gate_inp,
+ model.layers[il].ffn_up_exps,
+ model.layers[il].ffn_gate_exps,
+ model.layers[il].ffn_down_exps,
+ model.layers[il].ffn_exp_probs_b,
+ n_expert, n_expert_used,
+ LLM_FFN_SILU,
+ hparams.expert_weights_norm, // norm_w (route_norm=True)
+ hparams.expert_weights_scale, // scale_w
+ hparams.expert_weights_scale, // w_scale (route_scale=2.826)
+ (llama_expert_gating_func_type) hparams.expert_gating_func,
+ il);
+ cb(moe_out, "ffn_moe_out", il);
+
+ // shared expert
+ if (hparams.n_expert_shared > 0) {
+ ggml_tensor * ffn_shexp = build_ffn(cur,
+ model.layers[il].ffn_up_shexp, NULL, NULL,
+ model.layers[il].ffn_gate_shexp, NULL, NULL,
+ model.layers[il].ffn_down_shexp, NULL, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
+ cb(ffn_shexp, "ffn_shexp", il);
+
+ cur = ggml_add(ctx0, moe_out, ffn_shexp);
+ cb(cur, "ffn_out", il);
+ } else {
+ cur = moe_out;
+ }
+ } else {
+ // dense layer
+ cur = build_ffn(cur,
+ model.layers[il].ffn_up, NULL, NULL,
+ model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
+ cb(cur, "ffn_out", il);
+ }
+
+ // dual ffn normalization (post)
+ cur = build_norm(cur,
+ model.layers[il].ffn_post_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "ffn_post_norm", il);
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = build_norm(cur,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, -1);
+ cb(cur, "result_norm", -1);
+
+ res->t_embd = cur;
+
+ // lm_head
+ cur = build_lora_mm(model.output, cur);
+ cb(cur, "result_output", -1);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+}
#include "models.h"
-
-
llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
auto * inp_attn = build_attn_inp_kv();
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
- inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
int il) const;
};
+struct llm_build_afmoe : public llm_graph_context {
+ llm_build_afmoe(const llama_model & model, const llm_graph_params & params);
+};
+
struct llm_build_apertus : public llm_graph_context {
llm_build_apertus(const llama_model & model, const llm_graph_params & params);
};
auto * inp_attn = build_attn_inp_kv_iswa();
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
return bpe_offsets;
}
+// AFMOE digit handling: splits digits with leading 1-2 based on total length modulo 3
+static std::vector<size_t> unicode_regex_split_custom_afmoe(const std::string & text, const std::vector<size_t> & offsets) {
+ std::vector<size_t> bpe_offsets;
+ bpe_offsets.reserve(offsets.size());
+
+ const auto cpts = unicode_cpts_from_utf8(text);
+
+ size_t start = 0;
+ for (auto offset : offsets) {
+ const size_t offset_ini = start;
+ const size_t offset_end = start + offset;
+ assert(offset_end <= cpts.size());
+ start = offset_end;
+
+ auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+ return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
+ };
+
+ size_t _prev_end = offset_ini;
+ auto _add_token = [&] (const size_t end) -> size_t {
+ assert(_prev_end <= end && end <= offset_end);
+ size_t len = end - _prev_end;
+ if (len > 0) {
+ bpe_offsets.push_back(len);
+ }
+ _prev_end = end;
+ return len;
+ };
+
+ for (size_t pos = offset_ini; pos < offset_end; ) {
+ const auto flags = _get_flags(pos);
+
+ // Handle digit sequences with special splitting logic
+ if (flags.is_number) {
+ size_t digit_start = pos;
+ size_t digit_count = 0;
+
+ // Count consecutive digits
+ while (_get_flags(pos).is_number && pos < offset_end) {
+ digit_count++;
+ pos++;
+ }
+
+ // Split based on total length modulo 3
+ size_t remainder = digit_count % 3;
+ size_t current = digit_start;
+
+ // Emit leading 1-2 digits if needed
+ if (remainder > 0) {
+ _add_token(current + remainder);
+ current += remainder;
+ }
+
+ // Emit groups of 3
+ while (current < digit_start + digit_count) {
+ _add_token(current + 3);
+ current += 3;
+ }
+ continue;
+ }
+
+ // For non-digits, just move forward
+ pos++;
+ }
+
+ // Add any remaining content
+ if (_prev_end < offset_end) {
+ _add_token(offset_end);
+ }
+ }
+
+ return bpe_offsets;
+}
+
static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets;
} else if (regex_expr == "\\p{Han}+") {
// K2's first pattern - handle all K2 patterns together
bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
+ } else if (regex_expr == "\\p{AFMoE_digits}") {
+ // AFMOE digit pattern - use custom implementation for proper splitting
+ bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets);
}
return bpe_offsets;