]
# n_group and d_inner are used during reshape_tensors for mamba2
- self.d_model = self.find_hparam(["hidden_size", "d_model"])
- self.n_group = self.find_hparam(["n_groups"])
- self.d_inner = self.find_hparam(["expand"]) * self.d_model
+ # NOTE: Explicitly include hparam prefix prefix for d_model to
+ # disambiguate with top-level head_dim
+ # NOTE 2: If needed for future models, this can be isolated in a method
+ # to separate the prefix setting and teh keys used
+ self.d_model = self.find_hparam([f"{self.hparam_prefixes[0]}_head_dim", "hidden_size", "d_model"])
+ self.n_group = self.find_hparam(["n_groups", "num_groups"])
+ self.d_inner = self.find_hparam(["expand", "num_heads"]) * self.d_model
def get_attn_layers(self):
# Explicit list of layer type names
## Mamba mixer params ##
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
- self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+ self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state", "state_dim", "ssm_state_size"]))
self.gguf_writer.add_ssm_group_count(self.n_group)
self.gguf_writer.add_ssm_inner_size(self.d_inner)
# NOTE: The mamba_dt_rank is _not_ the right field for how this is used
# in llama.cpp
- self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+ self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads", "num_heads"]))
## Attention params ##
head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
Mamba2Model.set_vocab(self)
+@ModelBase.register("NemotronHForCausalLM")
+class NemotronHModel(GraniteHybridModel):
+ """Hybrid mamba2/attention model from NVIDIA"""
+ model_arch = gguf.MODEL_ARCH.NEMOTRON_H
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # Save the top-level head_dim for later
+ self.head_dim = self.hparams.get("head_dim", self.hparams.get("attention_head_dim"))
+ assert self.head_dim is not None, "Could not find the attention head dim in config"
+
+ # Don't use expand to calculate d_inner
+ self.d_inner = self.find_hparam(["num_heads"]) * self.d_model
+
+ # Update the ssm / attn / mlp layers
+ # M: Mamba2, *: Attention, -: MLP
+ hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
+ self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
+ self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]
+
+ def get_attn_layers(self):
+ hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
+ assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!"
+ return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"]
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+
+ self.gguf_writer.add_key_length(self.head_dim)
+ self.gguf_writer.add_value_length(self.head_dim)
+
+ # Set feed_forward_length
+ # NOTE: This will trigger an override warning. This is preferrable to
+ # duplicating all the parent logic
+ n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
+ self.gguf_writer.add_feed_forward_length([
+ n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
+ ])
+
+ def set_vocab(self):
+ super().set_vocab()
+
+ # The tokenizer _does_ add a BOS token (via post_processor type
+ # TemplateProcessing) but does not set add_bos_token to true in the
+ # config, so we need to explicitly override it here.
+ self.gguf_writer.add_add_bos_token(True)
+
+
@ModelBase.register("BailingMoeForCausalLM")
class BailingMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.BAILINGMOE
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_NEMOTRON_H:
+ {
+ ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
+ ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
+ ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
+ ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+ ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
+
+ // A layer is recurrent IFF the n_head_kv value is set to 0 and
+ // the n_ff value is set to 0
+ for (uint32_t i = 0; i < hparams.n_layer; ++i) {
+ hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0);
+ }
+
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) {
+ case 56: type = LLM_TYPE_9B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
case LLM_ARCH_EXAONE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
}
} break;
+ case LLM_ARCH_NEMOTRON_H:
+ {
+ // mamba2 Mixer SSM params
+ // NOTE: int64_t for tensor dimensions
+ const int64_t d_conv = hparams.ssm_d_conv;
+ const int64_t d_inner = hparams.ssm_d_inner;
+ const int64_t d_state = hparams.ssm_d_state;
+ const int64_t n_ssm_head = hparams.ssm_dt_rank;
+ const int64_t n_group = hparams.ssm_n_group;
+ const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
+
+ // embeddings
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ {
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+ // if output is NULL, init from the input tok embed, duplicated to allow offloading
+ if (output == NULL) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ // all blocks use the attn norm
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ if (hparams.is_recurrent(i)) {
+ // ssm layers
+ layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
+
+ layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
+ layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED);
+
+ layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0);
+
+ // no "weight" suffix for these
+ layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0);
+ layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0);
+
+ layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
+
+ // out_proj
+ layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
+ } else if (hparams.n_ff(i) == 0) {
+ // attention layers (with optional bias)
+ const int64_t n_head_i = hparams.n_head(i);
+ const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i);
+ const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i);
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0);
+ layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+ layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED);
+ layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED);
+ layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+ } else {
+ // mlp layers
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0);
+ layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+ layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED);
+ }
+ }
+ } break;
case LLM_ARCH_EXAONE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
arch == LLM_ARCH_JAMBA ||
arch == LLM_ARCH_FALCON_H1 ||
arch == LLM_ARCH_PLAMO2 ||
- arch == LLM_ARCH_GRANITE_HYBRID) {
+ arch == LLM_ARCH_GRANITE_HYBRID ||
+ arch == LLM_ARCH_NEMOTRON_H) {
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
}
};
+struct llm_build_nemotron_h : public llm_graph_context_mamba {
+ llm_build_nemotron_h(
+ const llama_model & model,
+ const llm_graph_params & params) :
+ llm_graph_context_mamba(params) {
+
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ auto * inp = build_inp_mem_hybrid();
+
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+
+ // norm
+ cur = build_norm(inpL,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ if (hparams.is_recurrent(il)) {
+ // ssm layer //
+ cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il);
+ } else if (hparams.n_ff(il) == 0) {
+ // attention layer //
+ cur = build_attention_layer(cur, inp->get_attn(), model, n_embd_head, il);
+ } else {
+ cur = build_ffn_layer(cur, model, il);
+ }
+
+ if (il == n_layer - 1 && inp_out_ids) {
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ // add residual
+ cur = ggml_add(ctx0, cur, inpSA);
+ cb(cur, "block_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = build_norm(cur,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur;
+
+ // lm_head
+ cur = build_lora_mm(model.output, cur);
+ cb(cur, "result_output", -1);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+ }
+
+ ggml_tensor * build_attention_layer(
+ ggml_tensor * cur,
+ llm_graph_input_attn_kv * inp_attn,
+ const llama_model & model,
+ const int64_t n_embd_head,
+ const int il) {
+
+ // compute Q and K and (optionally) RoPE them
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ if (model.layers[il].bq) {
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+ cb(Qcur, "Qcur", il);
+ }
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+ if (model.layers[il].bk) {
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+ cb(Kcur, "Kcur", il);
+ }
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+ if (model.layers[il].bv) {
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+ cb(Vcur, "Vcur", il);
+ }
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+ cur = build_attn(inp_attn,
+ model.layers[il].wo, model.layers[il].bo,
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+ cb(cur, "attn_out", il);
+ return cur;
+ }
+
+ ggml_tensor * build_ffn_layer(
+ ggml_tensor * cur,
+ const llama_model & model,
+ const int il) {
+
+ cur = build_ffn(cur,
+ model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
+ NULL, NULL, NULL,
+ model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+ NULL,
+ LLM_FFN_RELU_SQR, LLM_FFN_PAR, il);
+ cb(cur, "ffn_out", il);
+
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ return cur;
+ }
+};
+
struct llm_build_exaone : public llm_graph_context {
llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
cparams.n_seq_max,
nullptr);
} else if (llm_arch_is_hybrid(arch)) {
+
+ // The main difference between hybrid architectures is the
+ // layer filters, so pick the right one here
+ llama_memory_hybrid::layer_filter_cb filter_attn = nullptr;
+ llama_memory_hybrid::layer_filter_cb filter_recr = nullptr;
+ if (arch == LLM_ARCH_FALCON_H1) {
+ filter_attn = [&](int32_t) { return true; };
+ filter_recr = [&](int32_t) { return true; };
+ } else if (arch == LLM_ARCH_NEMOTRON_H) {
+ filter_attn = [&](int32_t il) {
+ return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
+ };
+ filter_recr = [&](int32_t il) {
+ return hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
+ };
+ }
+
const auto padding = llama_kv_cache::get_padding(cparams);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
/* n_seq_max */ cparams.n_seq_max,
/* offload */ cparams.offload_kqv,
/* unified */ cparams.kv_unified,
- /* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
- /* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
+ /* filter_attn */ std::move(filter_attn),
+ /* filter_recr */ std::move(filter_recr));
} else {
const auto padding = llama_kv_cache::get_padding(cparams);
{
llm = std::make_unique<llm_build_nemotron>(*this, params);
} break;
+ case LLM_ARCH_NEMOTRON_H:
+ {
+ llm = std::make_unique<llm_build_nemotron_h>(*this, params);
+ } break;
case LLM_ARCH_EXAONE:
{
llm = std::make_unique<llm_build_exaone>(*this, params);
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
case LLM_ARCH_WAVTOKENIZER_DEC:
+ case LLM_ARCH_NEMOTRON_H:
return LLAMA_ROPE_TYPE_NONE;
// use what we call a normal RoPE, operating on pairs of consecutive head values