From: Gabe Goodhart Date: Fri, 29 Aug 2025 00:39:31 +0000 (-0600) Subject: nvidia nemotron nano v2 (nemotronh) (#15507) X-Git-Tag: upstream/0.0.6527~212 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=e8d99dd0b67f2ecc1e45fca8074a3a18c3e036d2;p=pkg%2Fggml%2Fsources%2Fllama.cpp nvidia nemotron nano v2 (nemotronh) (#15507) * feat: Add NEMOTRONH to python arch enum https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart * feat: Add NEMOTRONH to c++ arch enum https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart * feat: Add NEMOTRONH to llama-arch layer map https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart * feat: First pass at conversion for nemotronh https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart * feat: Add a verbose log for each tensor loaded This is really helpful for diagnosing mismatches between the expected and received tensors https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart * feat: First (broken) pass at nemotronh model architecture It generates tokens, just not valid ones! https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart * fix: Explicitly enable add_bos_token during conversion The `tokenizer.json`/`tokenizer_config.json` in the model are a bit contradictory. In the config, add_bos_token is set to False, but the tokenizer model itself has a post_processor that adds the BOS token via type: TemplateProcessing https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart * fix: Use relu2 (LLM_FFN_RELU_SQR) for activation in FFN layers https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart * fix: Only allocate attention cache for attention layers (not non-recurrent) https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart * fix: Move residual add to after every block https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart * fix: Use the correct norm tensor for the MLP blocks https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart * Nemotron-H: MLP gate cleanup (pass NULL for unused gate) This model does not use a gate in MLP blocks; pass NULLs for gate tensors to make intent clear and avoid unused-pointer noise. * SSM: respect ssm_dt_rank for dt_dim when provided Use GGUF-provided time_step_rank (ssm_dt_rank) to set dt_dim when > 0; fallback to max(64, n_embd/16). * fix: plamo2 - revert dt_dim to default (remove ssm_dt_rank usage) * Rename nemotronh to nemotron_h for consistency - Update architecture name from NEMOTRONH to NEMOTRON_H in constants.py - Change architecture string from 'nemotronh' to 'nemotron_h' in all files - Update enum LLM_ARCH_NEMOTRONH to LLM_ARCH_NEMOTRON_H - Update class name llm_build_nemotronh to llm_build_nemotron_h - Consistent naming with underscore convention (nemotron_h vs nemotronh) * feat: Support conversion for older NemotronH models https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart --------- Signed-off-by: Gabe Goodhart Co-authored-by: Maicon Domingues Co-authored-by: weatherman --- diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6c8a0340..df37c4a6 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7546,9 +7546,13 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel): ] # n_group and d_inner are used during reshape_tensors for mamba2 - self.d_model = self.find_hparam(["hidden_size", "d_model"]) - self.n_group = self.find_hparam(["n_groups"]) - self.d_inner = self.find_hparam(["expand"]) * self.d_model + # NOTE: Explicitly include hparam prefix prefix for d_model to + # disambiguate with top-level head_dim + # NOTE 2: If needed for future models, this can be isolated in a method + # to separate the prefix setting and teh keys used + self.d_model = self.find_hparam([f"{self.hparam_prefixes[0]}_head_dim", "hidden_size", "d_model"]) + self.n_group = self.find_hparam(["n_groups", "num_groups"]) + self.d_inner = self.find_hparam(["expand", "num_heads"]) * self.d_model def get_attn_layers(self): # Explicit list of layer type names @@ -7609,12 +7613,12 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel): ## Mamba mixer params ## self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"])) - self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"])) + self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state", "state_dim", "ssm_state_size"])) self.gguf_writer.add_ssm_group_count(self.n_group) self.gguf_writer.add_ssm_inner_size(self.d_inner) # NOTE: The mamba_dt_rank is _not_ the right field for how this is used # in llama.cpp - self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"])) + self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads", "num_heads"])) ## Attention params ## head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"]) @@ -7641,6 +7645,55 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel): Mamba2Model.set_vocab(self) +@ModelBase.register("NemotronHForCausalLM") +class NemotronHModel(GraniteHybridModel): + """Hybrid mamba2/attention model from NVIDIA""" + model_arch = gguf.MODEL_ARCH.NEMOTRON_H + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Save the top-level head_dim for later + self.head_dim = self.hparams.get("head_dim", self.hparams.get("attention_head_dim")) + assert self.head_dim is not None, "Could not find the attention head dim in config" + + # Don't use expand to calculate d_inner + self.d_inner = self.find_hparam(["num_heads"]) * self.d_model + + # Update the ssm / attn / mlp layers + # M: Mamba2, *: Attention, -: MLP + hybrid_override_pattern = self.hparams["hybrid_override_pattern"] + self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"] + self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"] + + def get_attn_layers(self): + hybrid_override_pattern = self.hparams["hybrid_override_pattern"] + assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!" + return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_key_length(self.head_dim) + self.gguf_writer.add_value_length(self.head_dim) + + # Set feed_forward_length + # NOTE: This will trigger an override warning. This is preferrable to + # duplicating all the parent logic + n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"]) + self.gguf_writer.add_feed_forward_length([ + n_ff if i in self._mlp_layers else 0 for i in range(self.block_count) + ]) + + def set_vocab(self): + super().set_vocab() + + # The tokenizer _does_ add a BOS token (via post_processor type + # TemplateProcessing) but does not set add_bos_token to true in the + # config, so we need to explicitly override it here. + self.gguf_writer.add_add_bos_token(True) + + @ModelBase.register("BailingMoeForCausalLM") class BailingMoeModel(TextModel): model_arch = gguf.MODEL_ARCH.BAILINGMOE diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a581f960..6156d35c 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -367,6 +367,7 @@ class MODEL_ARCH(IntEnum): T5ENCODER = auto() JAIS = auto() NEMOTRON = auto() + NEMOTRON_H = auto() EXAONE = auto() EXAONE4 = auto() GRANITE = auto() @@ -700,6 +701,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.T5ENCODER: "t5encoder", MODEL_ARCH.JAIS: "jais", MODEL_ARCH.NEMOTRON: "nemotron", + MODEL_ARCH.NEMOTRON_H: "nemotron_h", MODEL_ARCH.EXAONE: "exaone", MODEL_ARCH.EXAONE4: "exaone4", MODEL_ARCH.GRANITE: "granite", @@ -2297,6 +2299,25 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.NEMOTRON_H: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.EXAONE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index abb21fa8..497f4880 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -191,6 +191,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.q_proj", # llama4 "model.transformer.blocks.{bid}.q_proj", # llada "layers.{bid}.self_attn.q_proj", # qwen3-embedding + "backbone.layers.{bid}.mixer.q_proj", # nemotron-h ), # Attention key @@ -209,6 +210,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.k_proj", # llama4 "model.transformer.blocks.{bid}.k_proj", # llada "layers.{bid}.self_attn.k_proj", # qwen3-embedding + "backbone.layers.{bid}.mixer.k_proj", # nemotron-h ), # Attention value @@ -226,6 +228,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.v_proj", # llama4 "model.transformer.blocks.{bid}.v_proj", # llada "layers.{bid}.self_attn.v_proj", # qwen3-embedding + "backbone.layers.{bid}.mixer.v_proj", # nemotron-h ), # Attention output @@ -260,6 +263,7 @@ class TensorNameMap: "transformer_encoder.{bid}.wo", # neobert "model.transformer.blocks.{bid}.attn_out", # llada "layers.{bid}.self_attn.o_proj", # qwen3-embedding + "backbone.layers.{bid}.mixer.o_proj", # nemotron-h ), # Attention output norm @@ -387,6 +391,7 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.up", # smallthinker "model.transformer.blocks.{bid}.up_proj", # llada "layers.{bid}.mlp.up_proj", # qwen3-embedding + "backbone.layers.{bid}.mixer.up_proj", # nemotron-h ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -480,6 +485,7 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.down", # smallthinker "model.transformer.blocks.{bid}.ff_out", # llada "layers.{bid}.mlp.down_proj", # qwen3-embedding + "backbone.layers.{bid}.mixer.down_proj", # nemotron-h ), MODEL_TENSOR.FFN_DOWN_EXP: ( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a61dc177..d5c8477f 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -69,6 +69,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_NEMOTRON, "nemotron" }, + { LLM_ARCH_NEMOTRON_H, "nemotron_h" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE4, "exaone4" }, { LLM_ARCH_RWKV6, "rwkv6" }, @@ -1550,6 +1551,31 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_NEMOTRON_H, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + // mamba(2) ssm layers + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + // attention layers + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + // dense FFN + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_EXAONE, { @@ -2355,6 +2381,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_PLAMO2: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_LFM2: + case LLM_ARCH_NEMOTRON_H: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index 94b0bef7..86c11969 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -73,6 +73,7 @@ enum llm_arch { LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, LLM_ARCH_NEMOTRON, + LLM_ARCH_NEMOTRON_H, LLM_ARCH_EXAONE, LLM_ARCH_EXAONE4, LLM_ARCH_RWKV6, diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index f71c40f8..8182a9ad 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -788,6 +788,7 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri } struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags) { + LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, name.c_str()); const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED)); if (cur == NULL) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 30974a72..f3e0e9ac 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1570,6 +1570,27 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_NEMOTRON_H: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // A layer is recurrent IFF the n_head_kv value is set to 0 and + // the n_ff value is set to 0 + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 56: type = LLM_TYPE_9B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_EXAONE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -4688,6 +4709,75 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_NEMOTRON_H: + { + // mamba2 Mixer SSM params + // NOTE: int64_t for tensor dimensions + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_ssm_head = hparams.ssm_dt_rank; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // all blocks use the attn norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.is_recurrent(i)) { + // ssm layers + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else if (hparams.n_ff(i) == 0) { + // attention layers (with optional bias) + const int64_t n_head_i = hparams.n_head(i); + const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + } else { + // mlp layers + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_EXAONE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5862,7 +5952,8 @@ void llama_model::print_info() const { arch == LLM_ARCH_JAMBA || arch == LLM_ARCH_FALCON_H1 || arch == LLM_ARCH_PLAMO2 || - arch == LLM_ARCH_GRANITE_HYBRID) { + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); @@ -14129,6 +14220,138 @@ struct llm_build_nemotron : public llm_graph_context { } }; +struct llm_build_nemotron_h : public llm_graph_context_mamba { + llm_build_nemotron_h( + const llama_model & model, + const llm_graph_params & params) : + llm_graph_context_mamba(params) { + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + if (hparams.is_recurrent(il)) { + // ssm layer // + cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); + } else if (hparams.n_ff(il) == 0) { + // attention layer // + cur = build_attention_layer(cur, inp->get_attn(), model, n_embd_head, il); + } else { + cur = build_ffn_layer(cur, model, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // add residual + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "block_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + + ggml_tensor * build_attention_layer( + ggml_tensor * cur, + llm_graph_input_attn_kv * inp_attn, + const llama_model & model, + const int64_t n_embd_head, + const int il) { + + // compute Q and K and (optionally) RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + return cur; + } + + ggml_tensor * build_ffn_layer( + ggml_tensor * cur, + const llama_model & model, + const int il) { + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + return cur; + } +}; + struct llm_build_exaone : public llm_graph_context { llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -18277,6 +18500,23 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, nullptr); } else if (llm_arch_is_hybrid(arch)) { + + // The main difference between hybrid architectures is the + // layer filters, so pick the right one here + llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; + llama_memory_hybrid::layer_filter_cb filter_recr = nullptr; + if (arch == LLM_ARCH_FALCON_H1) { + filter_attn = [&](int32_t) { return true; }; + filter_recr = [&](int32_t) { return true; }; + } else if (arch == LLM_ARCH_NEMOTRON_H) { + filter_attn = [&](int32_t il) { + return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + }; + filter_recr = [&](int32_t il) { + return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + }; + } + const auto padding = llama_kv_cache::get_padding(cparams); cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); @@ -18296,8 +18536,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* n_seq_max */ cparams.n_seq_max, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, - /* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr, - /* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr); + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); } else { const auto padding = llama_kv_cache::get_padding(cparams); @@ -18625,6 +18865,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_NEMOTRON_H: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_EXAONE: { llm = std::make_unique(*this, params); @@ -18860,6 +19104,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_RWKV7: case LLM_ARCH_ARWKV7: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_NEMOTRON_H: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values