git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
models : move the token embedding norms to the first layer (#20943)
author Georgi Gerganov <redacted>
Tue, 24 Mar 2026 15:00:30 +0000 (17:00 +0200)
committer GitHub <redacted>
Tue, 24 Mar 2026 15:00:30 +0000 (17:00 +0200)
* models : move the token embedding norms to the first layer

* cont : fix LLM_TENSOR_CONV1D + fix il indexing

src/llama-arch.cpp
src/llama-model.cpp
src/models/bert.cpp
src/models/bloom.cpp
src/models/modern-bert.cpp
src/models/rwkv6.cpp
src/models/rwkv7.cpp
src/models/wavtokenizer-dec.cpp

index 84dc6d8f1b6543088ec096fa82084bd5a9c84846..7019766b81c4cdb99d832a634e9cf5786f51d986 100644 (file)
@@ -2564,7 +2564,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_TOKEN_EMBD,                 {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
     {LLM_TENSOR_POS_EMBD,                   {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
     {LLM_TENSOR_TOKEN_TYPES,                {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
-    {LLM_TENSOR_TOKEN_EMBD_NORM,            {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}},
+    {LLM_TENSOR_TOKEN_EMBD_NORM,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},  // do the norms on the first layer (not the input layer)
     {LLM_TENSOR_OUTPUT,                     {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CLS,                        {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CLS_OUT,                    {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
@@ -2725,7 +2725,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_LAUREL_POST_NORM,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
-    {LLM_TENSOR_CONV1D,                     {LLM_TENSOR_LAYER_INPUT,     GGML_OP_IM2COL}},
+    {LLM_TENSOR_CONV1D,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
     {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_POS_NET_NORM2,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
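The two hunks above reclassify LLM_TENSOR_TOKEN_EMBD_NORM and LLM_TENSOR_CONV1D from LLM_TENSOR_LAYER_INPUT to LLM_TENSOR_LAYER_REPEATING, so both tensors are now grouped with the repeating blocks (starting at block 0) rather than with the input layer. A minimal standalone C++ sketch of that idea — hypothetical names and placement rules, not llama.cpp's actual loader:

    // Sketch: a LAYER_INPUT / LAYER_REPEATING / LAYER_OUTPUT classification
    // deciding which buffer group a tensor is created in (hypothetical).
    #include <cstdio>
    #include <map>
    #include <string>

    enum tensor_layer { LAYER_INPUT, LAYER_REPEATING, LAYER_OUTPUT };

    static const std::map<std::string, tensor_layer> infos = {
        { "token_embd",      LAYER_INPUT     },
        { "token_embd_norm", LAYER_REPEATING }, // after this commit
        { "conv1d",          LAYER_REPEATING }, // after this commit
        { "output",          LAYER_OUTPUT    },
    };

    int main() {
        for (const auto & [name, layer] : infos) {
            const char * where =
                layer == LAYER_INPUT     ? "input-layer buffers" :
                layer == LAYER_REPEATING ? "per-block buffers"   :
                                           "output-layer buffers";
            std::printf("%-16s -> %s\n", name.c_str(), where);
        }
    }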
index f8caad2889b0440415a929307c4cc6c59273a002..9f9098f405e1ad67ceb3e06abf163025b079b4f5 100644 (file)
@@ -3217,8 +3217,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"),   {hparams.n_cls_out},         TENSOR_NOT_REQUIRED);
                     }
 
-                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
+                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias",   0), {n_embd}, 0);
 
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
@@ -3265,7 +3265,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             case LLM_ARCH_MODERN_BERT:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);
 
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
 
@@ -3348,8 +3348,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     tok_embd  = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0); // word_embeddings
                     type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings
 
-                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm
-                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0); //LayerNorm bias
+                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); // LayerNorm
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias",   0), {n_embd}, 0); // LayerNorm bias
 
                     cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED);
                     cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {1},         TENSOR_NOT_REQUIRED);
@@ -3400,8 +3400,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             case LLM_ARCH_BLOOM:
                 {
                     tok_embd   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}, 0);
-                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
+                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias",   0), {n_embd}, 0);
 
                     // output
                     output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
@@ -5780,8 +5780,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // Block 0, LN0
-                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias",   0), {n_embd}, 0);
 
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
@@ -5895,8 +5895,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // Block 0, LN0
-                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias",   0), {n_embd}, 0);
 
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
@@ -6067,8 +6067,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0);
 
-                    conv1d   = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd, hparams.posnet.n_embd}, 0);
-                    conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"),   {1, hparams.posnet.n_embd}, 0);
+                    conv1d   = create_tensor(tn(LLM_TENSOR_CONV1D, "weight", 0), {7, hparams.n_embd, hparams.posnet.n_embd}, 0);
+                    conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias",   0), {1, hparams.posnet.n_embd}, 0);
 
                     // posnet
                     {
@@ -6133,8 +6133,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                     GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd);
 
-                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0);
-                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {hparams.posnet.n_embd}, 0);
+                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {hparams.posnet.n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias",   0), {hparams.posnet.n_embd}, 0);
 
                     // convnext
                     {
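Throughout the loading code above, the only change is that each create_tensor call now passes an explicit block id of 0 as the third argument of tn(...), tying the token embedding norms (and the WavTokenizer conv1d) to block 0. A rough standalone sketch of a tn()-style formatter — the helper below is hypothetical, not llama.cpp's LLM_TN:

    // Sketch: an optional block id is substituted into the name template
    // when it contains a %d; a template without %d keeps its plain name,
    // but the id can still record which block owns the tensor (hypothetical).
    #include <cstdio>
    #include <string>

    static std::string tn(const char * tmpl, const char * suffix, int bid = -1) {
        char buf[128];
        std::snprintf(buf, sizeof(buf), tmpl, bid); // excess arg is ignored when tmpl has no %d
        return std::string(buf) + "." + suffix;
    }

    int main() {
        std::printf("%s\n", tn("blk.%d.attn_norm", "weight", 0).c_str()); // blk.0.attn_norm.weight
        std::printf("%s\n", tn("token_embd_norm",  "weight", 0).c_str()); // token_embd_norm.weight
    }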
index 8733179141885ce7f1a8f9032d45a13a93cd743e..6ab8c13685823c927dc8c4f2affd8f094286136e 100644 (file)
@@ -28,8 +28,8 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params
     cb(inpL, "inp_embd", -1);
 
     // embed layer norm
-    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
-    cb(inpL, "inp_norm", -1);
+    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0);
+    cb(inpL, "inp_norm", 0);
 
     auto * inp_attn = build_attn_inp_no_cache();
 
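In the graph builders (this hunk and the ones below), the same move shows up as the il argument of build_norm and cb changing from -1 to 0: the input norm is now attributed to the first layer instead of to no layer. A tiny sketch of that convention with a hypothetical callback, not the actual cb:

    // Sketch: il >= 0 attributes a graph node to a layer; il == -1 means
    // the node is not tied to any layer (hypothetical callback).
    #include <cstdio>

    static void cb(const char * name, int il) {
        if (il >= 0) {
            std::printf("%s-%d (layer %d)\n", name, il, il);
        } else {
            std::printf("%s (no layer)\n", name);
        }
    }

    int main() {
        cb("inp_norm", -1); // before this commit
        cb("inp_norm",  0); // after this commit
    }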
index b1c19bb58a2e07dbe5db557b6d54110bbf6b3378..aa4b939b711d7d75f0ebe1b6eb51278a88d87c84 100644 (file)
@@ -16,8 +16,8 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_para
     inpL = build_norm(inpL,
             model.tok_norm,
             model.tok_norm_b,
-            LLM_NORM, -1);
-    cb(inpL, "inp_norm", -1);
+            LLM_NORM, 0);
+    cb(inpL, "inp_norm", 0);
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
index 26020584c6dbff95a7dc0da2ba58eb067fc2cae6..7662321093435f93250062ef1c1a867ec26f6482 100644 (file)
@@ -15,8 +15,8 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll
     cb(inpL, "inp_embd", -1);
 
     // embed layer norm
-    inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1);
-    cb(inpL, "inp_norm", -1);
+    inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, 0);
+    cb(inpL, "inp_norm", 0);
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
index 15453fbf50f518d589a30cacfe3ed7311a828ef8..032b219d6cbf6c5ea6fab44967f632c811d73b0c 100644 (file)
@@ -8,7 +8,7 @@ llm_build_rwkv6::llm_build_rwkv6(const llama_model & model, const llm_graph_para
     ggml_tensor * inpL;
 
     inpL = build_inp_embd(model.tok_embd);
-    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
+    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0);
 
     auto * rs_inp = build_rs_inp();
 
index 5caf6553dfe1a34f20f7f05dfd32c27d9e514ae0..16ffa6901b958b2a6ec4c84e572571f4bb412f26 100644 (file)
@@ -9,7 +9,7 @@ llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_para
     ggml_tensor * v_first = nullptr;
 
     inpL = build_inp_embd(model.tok_embd);
-    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
+    inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0);
 
     auto * rs_inp = build_rs_inp();
 
index 537a0d41248b6fd23666644611e6ab51b8523b88..a7776d9cdc97ea10e63f7f17cf9d2ea3509608bd 100644 (file)
@@ -93,7 +93,7 @@ llm_build_wavtokenizer_dec::llm_build_wavtokenizer_dec(const llama_model & model
     cur = build_norm(cur,
             model.tok_norm,
             model.tok_norm_b,
-            LLM_NORM, -1);
+            LLM_NORM, 0);
 
     cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));