llama : support T5 models with unequal number of encoder-decoder layers (#15909)
author Jie Fu (傅杰) <redacted>
Wed, 10 Sep 2025 18:51:51 +0000 (02:51 +0800)
committer GitHub <redacted>
Wed, 10 Sep 2025 18:51:51 +0000 (20:51 +0200)
* Extend support for T5 models with different numbers of encoder and decoder layers

Signed-off-by: Jie Fu <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update gguf-py/gguf/constants.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-arch.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-arch.h

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-hparams.h

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp (a series of review suggestions)

Co-authored-by: Sigbjørn Skjæret <redacted>
* Rename n_dec_layer --> dec_n_layer

Signed-off-by: Jie Fu <redacted>
* Adapt to cases when dec_n_layer > n_layer

Signed-off-by: Jie Fu <redacted>
---------

Signed-off-by: Jie Fu <redacted>
Co-authored-by: Sigbjørn Skjæret <redacted>
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-hparams.h
src/llama-model.cpp

convert_hf_to_gguf.py
index 62a546ee2220147dabe43834b75d5bdc0bde7111..bbc21813f81ca171c9b01b509f200dc28e3cf575 100755 (executable)
@@ -6701,6 +6701,8 @@ class T5Model(TextModel):
         self.gguf_writer.add_embedding_length(self.hparams["d_model"])
         self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
         self.gguf_writer.add_block_count(self.hparams["num_layers"])
+        if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None:
+            self.gguf_writer.add_decoder_block_count(dec_n_layer)
         self.gguf_writer.add_head_count(self.hparams["num_heads"])
         self.gguf_writer.add_key_length(self.hparams["d_kv"])
         self.gguf_writer.add_value_length(self.hparams["d_kv"])
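Note: the converter only emits the new key when the HuggingFace config actually contains num_decoder_layers; otherwise nothing is written and the decoder depth defaults to the encoder depth at load time. A minimal sketch of the same pattern, using a hypothetical config dict rather than a real checkpoint:

    # Hypothetical HF-style hparams: encoder and decoder depths differ.
    hparams = {"num_layers": 12, "num_decoder_layers": 24}

    if (dec_n_layer := hparams.get("num_decoder_layers")) is not None:
        # would be stored in the GGUF as "<arch>.decoder_block_count"
        print("decoder block count:", dec_n_layer)
    else:
        print("key absent; decoder depth falls back to num_layers at load time")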
gguf-py/gguf/constants.py
index b8ac394580b1f99c34d876be8fd5671fd2aa259a..1e88b6505bae06be2f592160733ab38b967c1eb0 100644 (file)
@@ -109,6 +109,7 @@ class Keys:
         POOLING_TYPE                      = "{arch}.pooling_type"
         LOGIT_SCALE                       = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID            = "{arch}.decoder_start_token_id"
+        DECODER_BLOCK_COUNT               = "{arch}.decoder_block_count"
         ATTN_LOGIT_SOFTCAPPING            = "{arch}.attn_logit_softcapping"
         FINAL_LOGIT_SOFTCAPPING           = "{arch}.final_logit_softcapping"
         SWIN_NORM                         = "{arch}.swin_norm"
gguf-py/gguf/gguf_writer.py
index a6cc8a931eb27783a22dfaae6fec38b664fe574c..7ff12f7f5709ddcdb8cc1b616e1384e22386a53c 100644 (file)
@@ -676,6 +676,9 @@ class GGUFWriter:
     def add_decoder_start_token_id(self, id: int) -> None:
         self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)
 
+    def add_decoder_block_count(self, value: int) -> None:
+        self.add_uint32(Keys.LLM.DECODER_BLOCK_COUNT.format(arch=self.arch), value)
+
     def add_embedding_length_per_layer_input(self, value: int) -> None:
         self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value)
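The writer stores the value under the per-architecture key template added in constants.py above. A small sketch, assuming gguf-py is importable, showing how that template expands for T5:

    # Assumes gguf-py is installed; Keys.LLM.DECODER_BLOCK_COUNT is the
    # template added in constants.py in this commit.
    from gguf.constants import Keys

    print(Keys.LLM.DECODER_BLOCK_COUNT.format(arch="t5"))  # "t5.decoder_block_count"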
 
src/llama-arch.cpp
index 77b2fecf18fb856041a0c29571a170bbb35829ad..81f9746818d4aa520f136f7c586f76011b7762ca 100644 (file)
@@ -137,6 +137,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_POOLING_TYPE,                      "%s.pooling_type"                      },
     { LLM_KV_LOGIT_SCALE,                       "%s.logit_scale"                       },
     { LLM_KV_DECODER_START_TOKEN_ID,            "%s.decoder_start_token_id"            },
+    { LLM_KV_DECODER_BLOCK_COUNT,               "%s.decoder_block_count"               },
     { LLM_KV_ATTN_LOGIT_SOFTCAPPING,            "%s.attn_logit_softcapping"            },
     { LLM_KV_FINAL_LOGIT_SOFTCAPPING,           "%s.final_logit_softcapping"           },
     { LLM_KV_SWIN_NORM,                         "%s.swin_norm"                         },
src/llama-arch.h
index 21ab47bd7af2a9288d386b5bb79e36ad95f406ae..6ee3707dcfbf64deea1c2d112803d24b2bc3eda7 100644 (file)
@@ -141,6 +141,7 @@ enum llm_kv {
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
+    LLM_KV_DECODER_BLOCK_COUNT,
     LLM_KV_ATTN_LOGIT_SOFTCAPPING,
     LLM_KV_FINAL_LOGIT_SOFTCAPPING,
     LLM_KV_SWIN_NORM,
src/llama-hparams.h
index 89f5c7ab65dce2ddf2c69ef871181a421a5efbc7..4dca2ca41d095af5f7c2b66602d105e627f4fe9e 100644 (file)
@@ -159,6 +159,7 @@ struct llama_hparams {
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
+    uint32_t    dec_n_layer        = 0;
 
     enum llama_pooling_type      pooling_type            = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
src/llama-model.cpp
index b9e4634a7061c81f1eaf0b7c7a2db6e6a495e0d8..818b209641a5a2c4a7f9da39065c2101d14dadc3 100644 (file)
@@ -1542,6 +1542,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.dec_start_token_id = dec_start_token_id;
                 }
 
+                hparams.dec_n_layer = hparams.n_layer;
+                ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false);
+
                 switch (hparams.n_layer) {
                     case 6:  type = LLM_TYPE_60M;  break; // t5-small
                     case 8:  type = LLM_TYPE_80M;  break; // flan-t5-small
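Here dec_n_layer is initialized to n_layer and only overridden when the optional key is present (the trailing false in ml.get_key marks it non-required), so existing symmetric T5 GGUFs keep working unchanged. An illustration of that fallback in Python, with a hypothetical decoder_layers helper standing in for the C++ logic and made-up metadata dicts:

    # Not the C++ code itself: decoder depth defaults to the encoder depth
    # unless the optional "<arch>.decoder_block_count" key is in the metadata.
    def decoder_layers(metadata: dict, n_layer: int, arch: str = "t5") -> int:
        return metadata.get(f"{arch}.decoder_block_count", n_layer)

    print(decoder_layers({}, 12))                              # 12 (old symmetric GGUF)
    print(decoder_layers({"t5.decoder_block_count": 24}, 12))  # 24 (new asymmetric GGUF)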
@@ -4414,6 +4417,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }
 
+                    // n_layer:     number of encoder_layers
+                    // dec_n_layer: number of decoder_layers
+                    const int dec_n_layer = hparams.dec_n_layer;
+                    if (dec_n_layer > n_layer) {
+                        layers.resize(dec_n_layer);
+                    }
+
+                    // load encoder layers
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
 
@@ -4429,6 +4440,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, TENSOR_NOT_REQUIRED);
                         layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                         layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+
+                    // load decoder layers
+                    for (int i = 0; i < dec_n_layer; ++i) {
+                        auto & layer = layers[i];
 
                         layer.attn_norm  = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM,  "weight", i), {n_embd}, 0);
                         layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
@@ -13509,7 +13525,9 @@ struct llm_build_t5_dec : public llm_graph_context {
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
-        for (int il = 0; il < n_layer; ++il) {
+        const int64_t dec_n_layer = hparams.dec_n_layer;
+
+        for (int il = 0; il < dec_n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
             // norm
@@ -13600,7 +13618,7 @@ struct llm_build_t5_dec : public llm_graph_context {
                 //cb(cur, "kqv_out", il);
             }
 
-            if (il == n_layer - 1 && inp_out_ids) {
+            if (il == dec_n_layer - 1 && inp_out_ids) {
                 cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
                 inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
             }
@@ -13621,8 +13639,8 @@ struct llm_build_t5_dec : public llm_graph_context {
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
                         NULL,
-                        model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
-                        model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
+                        model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU,
+                        model.layers[il].ffn_gate ? LLM_FFN_PAR : LLM_FFN_SEQ,
                         il);
                 cb(cur, "ffn_out", il);
             }