]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add Nemotron/Minitron GGUF Conversion & Inference Support (#8922)
authorYoshi Suhara <redacted>
Fri, 16 Aug 2024 02:23:33 +0000 (19:23 -0700)
committerGitHub <redacted>
Fri, 16 Aug 2024 02:23:33 +0000 (04:23 +0200)
* Add nemotron GGUF conversion & inference support

* Fix formatting issues

* Remove unnecessary write_tensors()

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <redacted>
* Update src/llama.cpp

Co-authored-by: compilade <redacted>
* Address comments by @compilade

* Replace ggml_mul_mat()->llm_build_lora_mm()

* Remove mutable variable

* Use  for bias tensors

* Cover corner case for role_scaling not in config.json

---------

Co-authored-by: compilade <redacted>
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
src/llama.cpp

index 41063d94b684e47dc607003304bfcdd7c1f3baaa..8f0e8680f6f911f7485276e6c2f7864c1678bfd6 100755 (executable)
@@ -3740,6 +3740,47 @@ class ChatGLMModel(Model):
         name = name.removeprefix("transformer.")
         return [(self.map_tensor_name(name), data_torch)]
 
+
+@Model.register("NemotronForCausalLM")
+class NemotronModel(Model):
+    model_arch = gguf.MODEL_ARCH.NEMOTRON
+
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+        self.gguf_writer.add_pad_token_id(0)
+        self.gguf_writer.add_unk_token_id(1)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+
+        f_norm_eps = self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon", "norm_eps"])
+        self.gguf_writer.add_layer_norm_eps(f_norm_eps)
+
+        # * Partial RoPE
+        rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"])
+        n_embd = self.find_hparam(["hidden_size", "n_embd"])
+        n_head = self.find_hparam(["num_attention_heads", "n_head"])
+        self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
+
+        # * RopeScaling for Nemotron
+        if "rope_scaling" not in self.hparams or self.hparams["rope_scaling"] is None:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+        else:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side
+        #   model.layers.{l}.input_layernorm.weight
+        #   model.layers.{l}.post_attention_layernorm.weight
+        #   model.norm.weight
+        if name.endswith("norm.weight"):
+            data_torch = data_torch + 1
+
+        return [(self.map_tensor_name(name), data_torch)]
+
 ###### CONVERSION LOGIC ######
 
 
index f63ec450a4e09aa137d31fdf23d0ec36c1e7fec1..63d51d87281d879a58b2d5e01012a54d7d67db9e 100644 (file)
@@ -219,6 +219,7 @@ class MODEL_ARCH(IntEnum):
     T5           = auto()
     T5ENCODER    = auto()
     JAIS         = auto()
+    NEMOTRON     = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -347,6 +348,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.T5:             "t5",
     MODEL_ARCH.T5ENCODER:      "t5encoder",
     MODEL_ARCH.JAIS:           "jais",
+    MODEL_ARCH.NEMOTRON:       "nemotron",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -1065,6 +1067,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.NEMOTRON: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
 
@@ -1105,6 +1122,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
     MODEL_ARCH.CHATGLM: [
         MODEL_TENSOR.ROPE_FREQS,
     ],
+    MODEL_ARCH.NEMOTRON: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
 }
 
 #
index 9aa2209e2f69f8f7ad3b6fbcc3b78e4dbf864349..a07f7f582d28940bf61cf74ec402fbd33f84ed6d 100644 (file)
@@ -13,7 +13,7 @@ class TensorNameMap:
             "transformer.wte",                           # gpt2 gpt-j mpt refact qwen dbrx jais
             "transformer.word_embeddings",               # falcon
             "word_embeddings",                           # bloom
-            "model.embed_tokens",                        # llama-hf
+            "model.embed_tokens",                        # llama-hf nemotron
             "tok_embeddings",                            # llama-pth
             "embeddings.word_embeddings",                # bert nomic-bert
             "language_model.embedding.word_embeddings",  # persimmon
@@ -52,7 +52,7 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out",                 # gptneox
-            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais
+            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron
             "output",                    # llama-pth bloom internlm2
             "word_embeddings_for_head",  # persimmon
             "lm_head.linear",            # phi2
@@ -75,6 +75,7 @@ class TensorNameMap:
             "transformer.rms_norm",                    # Grok
             "encoder.final_layernorm",                 # chatglm
             "transformer.norm",                        # openelm
+            "model.norm",                              # nemotron
         ),
 
         # Rope frequencies
@@ -93,7 +94,7 @@ class TensorNameMap:
             "transformer.h.{bid}.input_layernorm",                  # falcon7b
             "h.{bid}.input_layernorm",                              # bloom
             "transformer.h.{bid}.ln_mlp",                           # falcon40b
-            "model.layers.{bid}.input_layernorm",                   # llama-hf
+            "model.layers.{bid}.input_layernorm",                   # llama-hf nemotron
             "layers.{bid}.attention_norm",                          # llama-pth
             "language_model.encoder.layers.{bid}.input_layernorm",  # persimmon
             "model.layers.{bid}.ln1",                               # yi
@@ -135,7 +136,7 @@ class TensorNameMap:
 
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj",                       # llama-hf
+            "model.layers.{bid}.self_attn.q_proj",                       # llama-hf nemotron
             "layers.{bid}.attention.wq",                                 # llama-pth
             "encoder.layer.{bid}.attention.self.query",                  # bert
             "transformer.h.{bid}.attn.q_proj",                           # gpt-j
@@ -146,7 +147,7 @@ class TensorNameMap:
 
         # Attention key
         MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj",                     # llama-hf
+            "model.layers.{bid}.self_attn.k_proj",                     # llama-hf nemotron
             "layers.{bid}.attention.wk",                               # llama-pth
             "encoder.layer.{bid}.attention.self.key",                  # bert
             "transformer.h.{bid}.attn.k_proj",                         # gpt-j
@@ -158,7 +159,7 @@ class TensorNameMap:
 
         # Attention value
         MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj",                       # llama-hf
+            "model.layers.{bid}.self_attn.v_proj",                       # llama-hf nemotron
             "layers.{bid}.attention.wv",                                 # llama-pth
             "encoder.layer.{bid}.attention.self.value",                  # bert
             "transformer.h.{bid}.attn.v_proj",                           # gpt-j
@@ -175,7 +176,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.attn.out_proj",                       # mpt
             "transformer.h.{bid}.self_attention.dense",                     # falcon
             "h.{bid}.self_attention.dense",                                 # bloom
-            "model.layers.{bid}.self_attn.o_proj",                          # llama-hf
+            "model.layers.{bid}.self_attn.o_proj",                          # llama-hf nemotron
             "layers.{bid}.attention.wo",                                    # llama-pth
             "encoder.layer.{bid}.attention.output.dense",                   # bert
             "transformer.h.{bid}.attn.out_proj",                            # gpt-j
@@ -218,7 +219,7 @@ class TensorNameMap:
             "transformer.h.{bid}.ln_2",                                      # gpt2 refact qwen jais
             "h.{bid}.post_attention_layernorm",                              # bloom
             "transformer.blocks.{bid}.norm_2",                               # mpt
-            "model.layers.{bid}.post_attention_layernorm",                   # llama-hf
+            "model.layers.{bid}.post_attention_layernorm",                   # llama-hf nemotron
             "layers.{bid}.ffn_norm",                                         # llama-pth
             "language_model.encoder.layers.{bid}.post_attention_layernorm",  # persimmon
             "model.layers.{bid}.ln2",                                        # yi
@@ -258,7 +259,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.ffn.up_proj",                   # mpt
             "transformer.h.{bid}.mlp.dense_h_to_4h",                  # falcon
             "h.{bid}.mlp.dense_h_to_4h",                              # bloom
-            "model.layers.{bid}.mlp.up_proj",                         # llama-hf refact
+            "model.layers.{bid}.mlp.up_proj",                         # llama-hf refact nemotron
             "layers.{bid}.feed_forward.w3",                           # llama-pth
             "encoder.layer.{bid}.intermediate.dense",                 # bert
             "transformer.h.{bid}.mlp.fc_in",                          # gpt-j
@@ -329,7 +330,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.ffn.down_proj",                 # mpt
             "transformer.h.{bid}.mlp.dense_4h_to_h",                  # falcon
             "h.{bid}.mlp.dense_4h_to_h",                              # bloom
-            "model.layers.{bid}.mlp.down_proj",                       # llama-hf
+            "model.layers.{bid}.mlp.down_proj",                       # llama-hf nemotron
             "layers.{bid}.feed_forward.w2",                           # llama-pth
             "encoder.layer.{bid}.output.dense",                       # bert
             "transformer.h.{bid}.mlp.fc_out",                         # gpt-j
index ee36de977cdc2cdf9d3b4ada8294a22ce3181873..1128bf53b596835bf8994e8671f84c56c0bae0bb 100644 (file)
@@ -210,6 +210,7 @@ enum llm_arch {
     LLM_ARCH_T5,
     LLM_ARCH_T5ENCODER,
     LLM_ARCH_JAIS,
+    LLM_ARCH_NEMOTRON,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -255,6 +256,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_T5,              "t5"           },
     { LLM_ARCH_T5ENCODER,       "t5encoder"    },
     { LLM_ARCH_JAIS,            "jais"         },
+    { LLM_ARCH_NEMOTRON,        "nemotron"     },
     { LLM_ARCH_UNKNOWN,         "(unknown)"    },
 };
 
@@ -1296,6 +1298,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
         },
     },
+    {
+        LLM_ARCH_NEMOTRON,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -5235,6 +5255,14 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_NEMOTRON:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_4B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -7568,6 +7596,48 @@ static bool llm_load_tensors(
                         layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
                     }
                 } break;
+            case LLM_ARCH_NEMOTRON:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm   = ml.create_tensor(ctx_output,   tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
+                        model.output        = ml.create_tensor(ctx_output_split,  tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                        // optional bias tensors
+                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
+
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+
+                        // optional MLP bias
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -8254,7 +8324,7 @@ static struct ggml_tensor * llm_build_kqv(
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON) {
             // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
             // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
             ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -13755,6 +13825,128 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_nemotron() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        //GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm,
+                    model.layers[il].ffn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                    NULL,                      NULL,                        NULL,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                    NULL,
+                    LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -14010,6 +14202,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_jais();
             } break;
+        case LLM_ARCH_NEMOTRON:
+            {
+                result = llm.build_nemotron();
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -17080,6 +17276,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
         case LLM_ARCH_CODESHELL:
+        case LLM_ARCH_NEMOTRON:
             return LLAMA_ROPE_TYPE_NEOX;
 
         // all model arches should be listed explicitly here