git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add support for encoder-only T5 models (#8900)
author    fairydreaming <redacted>
Sat, 10 Aug 2024 09:43:26 +0000 (11:43 +0200)
committer GitHub <redacted>
Sat, 10 Aug 2024 09:43:26 +0000 (11:43 +0200)
* gguf-py : add T5ENCODER model architecture

* common : call llama_decode() during warmup only if the model has a decoder

* convert-hf : add T5EncoderModel

* llama : add llama_model_has_decoder() API function

* llama : split build_t5() into build_t5_encoder() and build_t5_decoder()

* llama : add support for LLM_ARCH_T5ENCODER

* llama-embedding : add support for LLAMA_POOLING_TYPE_NONE

* llama-embedding : add support for encoder-only models

---------

Co-authored-by: Stanisław Szymczyk <redacted>
common/common.cpp
convert_hf_to_gguf.py
examples/embedding/embedding.cpp
gguf-py/gguf/constants.py
include/llama.h
src/llama.cpp
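Taken together, these changes let an application convert and run an encoder-only T5 GGUF end to end for embeddings. A minimal sketch of the resulting flow (not part of the commit; the model path is hypothetical, and error handling is omitted):

    #include "llama.h"
    #include <cstdio>
    #include <vector>

    int main() {
        llama_backend_init();
        llama_model * model = llama_load_model_from_file("t5-encoder.gguf", llama_model_default_params());

        llama_context_params cparams = llama_context_default_params();
        cparams.embeddings   = true;
        cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // or LLAMA_POOLING_TYPE_NONE for per-token rows

        llama_context * ctx = llama_new_context_with_model(model, cparams);

        std::vector<llama_token> toks(32);
        const int n = llama_tokenize(model, "hello world", 11, toks.data(), toks.size(), true, false);

        // encoder-only models are driven with llama_encode(), never llama_decode()
        if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
            llama_encode(ctx, llama_batch_get_one(toks.data(), n, 0, 0));
            const float * embd = llama_get_embeddings_seq(ctx, 0); // one pooled row for sequence 0
            printf("n_embd = %d, embd[0] = %f\n", llama_n_embd(model), embd[0]);
        }

        llama_free(ctx);
        llama_free_model(model);
        llama_backend_free();
        return 0;
    }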

index 560e20d080d0f16995409b27aeb9908847d4f792..d3d896115ae36b9446ef1bdbe35aa8f8231571ce 100644 (file)
@@ -2156,7 +2156,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
             tmp.clear();
             tmp.push_back(decoder_start_token_id);
         }
-        llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+        if (llama_model_has_decoder(model)) {
+            llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+        }
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
         llama_reset_timings(lctx);
index 7136db440644b4a6e6c81c3e28694039e2a13cf5..550dd5cfda99f083ee0c2c24a6d288b31be8296f 100755 (executable)
@@ -3324,6 +3324,145 @@ class T5Model(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@Model.register("T5EncoderModel")
+class T5EncoderModel(Model):
+    model_arch = gguf.MODEL_ARCH.T5ENCODER
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.shared_token_embeddings_found = False
+
+    def set_vocab(self):
+        # to avoid the "TypeError: Descriptors cannot be created directly"
+        # exception raised when importing sentencepiece_model_pb2
+        os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+        from sentencepiece import SentencePieceProcessor
+        from sentencepiece import sentencepiece_model_pb2 as model
+
+        tokenizer_path = self.dir_model / 'tokenizer.model'
+
+        # many older models use 'spiece.model' as the tokenizer model filename
+        if not tokenizer_path.is_file():
+            tokenizer_path = self.dir_model / 'spiece.model'
+
+        if not tokenizer_path.is_file():
+            raise FileNotFoundError(f"File not found: {tokenizer_path}")
+
+        sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
+        sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+
+        # some models like Pile-T5 family use BPE tokenizer instead of Unigram
+        if sentencepiece_model.trainer_spec.model_type == 2:  # BPE
+            # ensure the tokenizer model file name is correct
+            assert tokenizer_path.name == 'tokenizer.model'
+            return self._set_vocab_sentencepiece()
+        else:
+            assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM
+
+        add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
+        remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
+        precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
+
+        tokenizer = SentencePieceProcessor()
+        tokenizer.LoadFromFile(str(tokenizer_path))
+
+        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+
+        tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
+        scores: list[float] = [-10000.0] * vocab_size
+        toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
+
+        for token_id in range(tokenizer.vocab_size()):
+            piece = tokenizer.IdToPiece(token_id)
+            text = piece.encode("utf-8")
+            score = tokenizer.GetScore(token_id)
+
+            toktype = SentencePieceTokenTypes.NORMAL
+            if tokenizer.IsUnknown(token_id):
+                toktype = SentencePieceTokenTypes.UNKNOWN
+            elif tokenizer.IsControl(token_id):
+                toktype = SentencePieceTokenTypes.CONTROL
+            elif tokenizer.IsUnused(token_id):
+                toktype = SentencePieceTokenTypes.UNUSED
+            elif tokenizer.IsByte(token_id):
+                toktype = SentencePieceTokenTypes.BYTE
+
+            tokens[token_id] = text
+            scores[token_id] = score
+            toktypes[token_id] = toktype
+
+        added_tokens_file = self.dir_model / 'added_tokens.json'
+        if added_tokens_file.is_file():
+            with open(added_tokens_file, "r", encoding="utf-8") as f:
+                added_tokens_json = json.load(f)
+                for key in added_tokens_json:
+                    token_id = added_tokens_json[key]
+                    if token_id >= vocab_size:
+                        logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
+                        continue
+
+                    tokens[token_id] = key.encode("utf-8")
+                    scores[token_id] = -1000.0
+                    toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
+
+        if vocab_size > len(tokens):
+            pad_count = vocab_size - len(tokens)
+            logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
+            for i in range(1, pad_count + 1):
+                tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
+                scores.append(-1000.0)
+                toktypes.append(SentencePieceTokenTypes.UNUSED)
+
+        self.gguf_writer.add_tokenizer_model("t5")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_add_space_prefix(add_prefix)
+        self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces)
+        if precompiled_charsmap:
+            self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+        self.gguf_writer.add_add_bos_token(False)
+        self.gguf_writer.add_add_eos_token(True)
+
+    def set_gguf_parameters(self):
+        if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
+            logger.warning("Couldn't find context length in config.json, assuming default value of 512")
+            n_ctx = 512
+        self.gguf_writer.add_context_length(n_ctx)
+        self.gguf_writer.add_embedding_length(self.hparams["d_model"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
+        self.gguf_writer.add_block_count(self.hparams["num_layers"])
+        self.gguf_writer.add_head_count(self.hparams["num_heads"])
+        self.gguf_writer.add_key_length(self.hparams["d_kv"])
+        self.gguf_writer.add_value_length(self.hparams["d_kv"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_relative_attn_buckets_count(self.hparams["relative_attention_num_buckets"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        # T5-based models store the shared token embeddings tensor under any of the names "encoder.embed_tokens.weight",
+        # "decoder.embed_tokens.weight" or "shared.weight". Some models even contain multiple copies of it in their
+        # safetensors files. We use the first of these tensors that we encounter as the token embeddings for both the
+        # encoder and the decoder, and ignore the remaining ones.
+        if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
+            if not self.shared_token_embeddings_found:
+                name = "shared.weight"
+                self.shared_token_embeddings_found = True
+            else:
+                logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @Model.register("JAISLMHeadModel")
 class JaisModel(Model):
     model_arch = gguf.MODEL_ARCH.JAIS
index cd7b448a619fa3e88d0c93296b1a65d5dfe798fe..b05aa006e7da51fcce9bde0984a69c8418265761 100644 (file)
@@ -31,13 +31,24 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
 }
 
 static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    const struct llama_model * model = llama_get_model(ctx);
+
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_cache_clear(ctx);
 
     // run model
     fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
-    if (llama_decode(ctx, batch) < 0) {
-        fprintf(stderr, "%s : failed to decode\n", __func__);
+    if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
+        // encoder-only model
+        if (llama_encode(ctx, batch) < 0) {
+            fprintf(stderr, "%s : failed to encode\n", __func__);
+        }
+    } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
+        // decoder-only model
+        if (llama_decode(ctx, batch) < 0) {
+            fprintf(stderr, "%s : failed to decode\n", __func__);
+        }
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
@@ -45,11 +56,22 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
             continue;
         }
 
-        // try to get sequence embeddings - supported only when pooling_type is not NONE
-        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+        const float * embd = nullptr;
+        int embd_pos = 0;
+
+        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+            // try to get token embeddings
+            embd = llama_get_embeddings_ith(ctx, i);
+            embd_pos = i;
+            GGML_ASSERT(embd != NULL && "failed to get token embeddings");
+        } else {
+            // try to get sequence embeddings - supported only when pooling_type is not NONE
+            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            embd_pos = batch.seq_id[i][0];
+            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+        }
 
-        float * out = output + batch.seq_id[i][0] * n_embd;
+        float * out = output + embd_pos * n_embd;
         llama_embd_normalize(embd, out, n_embd, embd_norm);
     }
 }
@@ -93,8 +115,9 @@ int main(int argc, char ** argv) {
     const int n_ctx = llama_n_ctx(ctx);
 
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
-    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+
+    if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
+        fprintf(stderr, "%s: error: computing embeddings in encoder-decoder models is not supported\n", __func__);
         return 1;
     }
 
@@ -153,13 +176,23 @@ int main(int argc, char ** argv) {
     const int n_prompts = prompts.size();
     struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
 
+    // count number of embeddings
+    int n_embd_count = 0;
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        for (int k = 0; k < n_prompts; k++) {
+            n_embd_count += inputs[k].size();
+        }
+    } else {
+        n_embd_count = n_prompts;
+    }
+
     // allocate output
     const int n_embd = llama_n_embd(model);
-    std::vector<float> embeddings(n_prompts * n_embd, 0);
+    std::vector<float> embeddings(n_embd_count * n_embd, 0);
     float * emb = embeddings.data();
 
     // break into batches
-    int p = 0; // number of prompts processed already
+    int e = 0; // number of embeddings already stored
     int s = 0; // number of prompts in current batch
     for (int k = 0; k < n_prompts; k++) {
         // clamp to n_batch tokens
@@ -169,11 +202,11 @@ int main(int argc, char ** argv) {
 
         // encode if at capacity
         if (batch.n_tokens + n_toks > n_batch) {
-            float * out = emb + p * n_embd;
+            float * out = emb + e * n_embd;
             batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
-            llama_batch_clear(batch);
-            p += s;
+            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
             s = 0;
+            llama_batch_clear(batch);
         }
 
         // add to batch
@@ -182,39 +215,62 @@ int main(int argc, char ** argv) {
     }
 
     // final batch
-    float * out = emb + p * n_embd;
+    float * out = emb + e * n_embd;
     batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
 
     if (params.embd_out.empty()) {
-        // print the first part of the embeddings or for a single prompt, the full embedding
         fprintf(stdout, "\n");
-        for (int j = 0; j < n_prompts; j++) {
-            fprintf(stdout, "embedding %d: ", j);
-            for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
-                if (params.embd_normalize == 0) {
-                    fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
-                } else {
-                    fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
+
+        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+            for (int j = 0; j < n_embd_count; j++) {
+                fprintf(stdout, "embedding %d: ", j);
+                for (int i = 0; i < std::min(3, n_embd); i++) {
+                    if (params.embd_normalize == 0) {
+                        fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
+                    } else {
+                        fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
+                    }
+                }
+                fprintf(stdout, " ... ");
+                for (int i = n_embd - 3; i < n_embd; i++) {
+                    if (params.embd_normalize == 0) {
+                        fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
+                    } else {
+                        fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
+                    }
                 }
+                fprintf(stdout, "\n");
             }
-            fprintf(stdout, "\n");
-        }
-
-        // print cosine similarity matrix
-        if (n_prompts > 1) {
-            fprintf(stdout, "\n");
-            printf("cosine similarity matrix:\n\n");
-            for (int i = 0; i < n_prompts; i++) {
-                fprintf(stdout, "%6.6s ", prompts[i].c_str());
+        } else {
+            // print the first part of the embeddings or for a single prompt, the full embedding
+            for (int j = 0; j < n_prompts; j++) {
+                fprintf(stdout, "embedding %d: ", j);
+                for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
+                    if (params.embd_normalize == 0) {
+                        fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
+                    } else {
+                        fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
+                    }
+                }
+                fprintf(stdout, "\n");
             }
-            fprintf(stdout, "\n");
-            for (int i = 0; i < n_prompts; i++) {
-                for (int j = 0; j < n_prompts; j++) {
-                    float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
-                    fprintf(stdout, "%6.2f ", sim);
+
+            // print cosine similarity matrix
+            if (n_prompts > 1) {
+                fprintf(stdout, "\n");
+                printf("cosine similarity matrix:\n\n");
+                for (int i = 0; i < n_prompts; i++) {
+                    fprintf(stdout, "%6.6s ", prompts[i].c_str());
                 }
-                fprintf(stdout, "%1.10s", prompts[i].c_str());
                 fprintf(stdout, "\n");
+                for (int i = 0; i < n_prompts; i++) {
+                    for (int j = 0; j < n_prompts; j++) {
+                        float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+                        fprintf(stdout, "%6.2f ", sim);
+                    }
+                    fprintf(stdout, "%1.10s", prompts[i].c_str());
+                    fprintf(stdout, "\n");
+                }
             }
         }
     }
@@ -233,23 +289,23 @@ int main(int argc, char ** argv) {
             }
             fprintf(stdout, notArray ? "]\n    }" : "]");
             j++;
-            if (j < n_prompts) fprintf(stdout, notArray ? ",\n" : ","); else break;
+            if (j < n_embd_count) fprintf(stdout, notArray ? ",\n" : ","); else break;
         }
         fprintf(stdout, notArray ? "\n  ]" : "]\n");
 
         if (params.embd_out == "json+" && n_prompts > 1) {
             fprintf(stdout, ",\n  \"cosineSimilarity\": [\n");
-            for (int i = 0;;) { // at least two iteration (n_prompts > 1)
+            for (int i = 0;;) { // at least two iterations (n_embd_count > 1)
                 fprintf(stdout, "    [");
-                for (int j = 0;;) { // at least two iteration (n_prompts > 1)
+                for (int j = 0;;) { // at least two iterations (n_embd_count > 1)
                     float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
                     fprintf(stdout, "%6.2f", sim);
                     j++;
-                    if (j < n_prompts) fprintf(stdout, ", "); else break;
+                    if (j < n_embd_count) fprintf(stdout, ", "); else break;
                 }
                 fprintf(stdout, " ]");
                 i++;
-                if (i < n_prompts) fprintf(stdout, ",\n"); else break;
+                if (i < n_embd_count) fprintf(stdout, ",\n"); else break;
             }
             fprintf(stdout, "\n  ]");
         }
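With LLAMA_POOLING_TYPE_NONE the output buffer above holds one row per token instead of one row per prompt, which is why the cursor e advances by batch.n_tokens rather than by the sequence count s. A worked example of the resulting layout (illustrative numbers, not from the commit):

    #include <vector>

    // two prompts of 3 and 5 tokens, n_embd = 4
    const int n_embd    = 4;
    const int rows_none = 3 + 5; // pooling NONE: one row per token  -> 8 rows
    const int rows_mean = 2;     // pooling MEAN: one row per prompt -> 2 rows

    std::vector<float> embeddings(rows_none * n_embd, 0.0f);
    float * emb  = embeddings.data();
    float * row5 = emb + 5 * n_embd; // 3rd token of prompt 1 (global token index 5)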
index 89efe0c800964a22dc9463419d25963fbe66c56c..f63ec450a4e09aa137d31fdf23d0ec36c1e7fec1 100644 (file)
@@ -217,6 +217,7 @@ class MODEL_ARCH(IntEnum):
     CHATGLM      = auto()
     BITNET       = auto()
     T5           = auto()
+    T5ENCODER    = auto()
     JAIS         = auto()
 
 
@@ -344,6 +345,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.CHATGLM:        "chatglm",
     MODEL_ARCH.BITNET:         "bitnet",
     MODEL_ARCH.T5:             "t5",
+    MODEL_ARCH.T5ENCODER:      "t5encoder",
     MODEL_ARCH.JAIS:           "jais",
 }
 
@@ -1036,6 +1038,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ENC_FFN_UP,
         MODEL_TENSOR.ENC_OUTPUT_NORM,
     ],
+    MODEL_ARCH.T5ENCODER: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ENC_ATTN_NORM,
+        MODEL_TENSOR.ENC_ATTN_Q,
+        MODEL_TENSOR.ENC_ATTN_K,
+        MODEL_TENSOR.ENC_ATTN_V,
+        MODEL_TENSOR.ENC_ATTN_OUT,
+        MODEL_TENSOR.ENC_ATTN_REL_B,
+        MODEL_TENSOR.ENC_FFN_NORM,
+        MODEL_TENSOR.ENC_FFN_GATE,
+        MODEL_TENSOR.ENC_FFN_DOWN,
+        MODEL_TENSOR.ENC_FFN_UP,
+        MODEL_TENSOR.ENC_OUTPUT_NORM,
+    ],
     MODEL_ARCH.JAIS: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index 66c266298e86f0204f6d466368abd8d86465a99c..ce07f4fac8f100ef1a365b59aed27ea267877998 100644 (file)
@@ -504,6 +504,9 @@ extern "C" {
     // Returns true if the model contains an encoder that requires llama_encode() call
     LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
 
+    // Returns true if the model contains a decoder that requires llama_decode() call
+    LLAMA_API bool llama_model_has_decoder(const struct llama_model * model);
+
     // For encoder-decoder models, this function returns id of the token that must be provided
     // to the decoder to start generating output sequence. For other models, it returns -1.
     LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
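Together with the existing llama_model_has_encoder(), the new predicate classifies any loaded model into the three cases the rest of this commit dispatches on. A small usage sketch (illustrative; model is a pointer to a loaded llama_model):

    const bool has_enc = llama_model_has_encoder(model);
    const bool has_dec = llama_model_has_decoder(model);

    if ( has_enc &&  has_dec) { /* encoder-decoder (t5):       llama_encode() first, then llama_decode() */ }
    if ( has_enc && !has_dec) { /* encoder-only (t5encoder):   llama_encode() only */ }
    if (!has_enc &&  has_dec) { /* decoder-only (most models): llama_decode() only */ }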
index 97dd1b3fea4b9444d4a6a1b714091a706613c412..9c4f2aa72164050829007691b063d668654f1868 100644 (file)
@@ -208,6 +208,7 @@ enum llm_arch {
     LLM_ARCH_CHATGLM,
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
+    LLM_ARCH_T5ENCODER,
     LLM_ARCH_JAIS,
     LLM_ARCH_UNKNOWN,
 };
@@ -252,6 +253,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_CHATGLM,         "chatglm"      },
     { LLM_ARCH_BITNET,          "bitnet"       },
     { LLM_ARCH_T5,              "t5"           },
+    { LLM_ARCH_T5ENCODER,       "t5encoder"    },
     { LLM_ARCH_JAIS,            "jais"         },
     { LLM_ARCH_UNKNOWN,         "(unknown)"    },
 };
@@ -1261,6 +1263,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_ENC_FFN_UP,           "enc.blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_T5ENCODER,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,           "token_embd" },
+            { LLM_TENSOR_OUTPUT,               "output" },
+            { LLM_TENSOR_ENC_OUTPUT_NORM,      "enc.output_norm" },
+            { LLM_TENSOR_ENC_ATTN_NORM,        "enc.blk.%d.attn_norm" },
+            { LLM_TENSOR_ENC_ATTN_Q,           "enc.blk.%d.attn_q" },
+            { LLM_TENSOR_ENC_ATTN_K,           "enc.blk.%d.attn_k" },
+            { LLM_TENSOR_ENC_ATTN_V,           "enc.blk.%d.attn_v" },
+            { LLM_TENSOR_ENC_ATTN_OUT,         "enc.blk.%d.attn_o" },
+            { LLM_TENSOR_ENC_ATTN_REL_B,       "enc.blk.%d.attn_rel_b" },
+            { LLM_TENSOR_ENC_FFN_NORM,         "enc.blk.%d.ffn_norm" },
+            { LLM_TENSOR_ENC_FFN_GATE,         "enc.blk.%d.ffn_gate" },
+            { LLM_TENSOR_ENC_FFN_DOWN,         "enc.blk.%d.ffn_down" },
+            { LLM_TENSOR_ENC_FFN_UP,           "enc.blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_JAIS,
         {
@@ -5187,6 +5207,12 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                }
             } break;
+        case LLM_ARCH_T5ENCODER:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
+                model.type = e_model::MODEL_UNKNOWN;
+            } break;
         case LLM_ARCH_JAIS:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -7421,6 +7447,42 @@ static bool llm_load_tensors(
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP,   "weight", i), {n_embd,   n_ff});
                     }
                 } break;
+            case LLM_ARCH_T5ENCODER:
+                {
+                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
+
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        // if output is NULL, init from the input tok embed
+                        if (model.output == NULL) {
+                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                        }
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm_enc  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd});
+                        layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                        layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
+
+                        layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up_enc   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                    }
+                } break;
             case LLM_ARCH_JAIS:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -13135,7 +13197,7 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_t5() {
+    struct ggml_cgraph * build_t5_encoder() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
@@ -13150,303 +13212,323 @@ struct llm_build_context {
 
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
-        if (lctx.is_encoding) {
-            struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
+        GGML_ASSERT(lctx.is_encoding);
+        struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
 
-            // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-            struct ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false);
-
-            for (int il = 0; il < n_layer; ++il) {
-                struct ggml_tensor * inpSA = inpL;
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false);
 
-                // norm
-                cur = llm_build_norm(ctx0, inpL, hparams,
-                        model.layers[il].attn_norm_enc, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "attn_norm", il);
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
 
-                // self-attention
-                {
-                    struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_enc, cur);
-                    cb(Qcur, "Qcur", il);
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm_enc, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
 
-                    struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_enc, cur);
-                    cb(Kcur, "Kcur", il);
+            // self-attention
+            {
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_enc, cur);
+                cb(Qcur, "Qcur", il);
 
-                    struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_enc, cur);
-                    cb(Vcur, "Vcur", il);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_enc, cur);
+                cb(Kcur, "Kcur", il);
 
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_enc, cur);
+                cb(Vcur, "Vcur", il);
 
-                    struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-                    struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
-                    struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-                    cb(kq, "kq", il);
+                struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+                struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
 
-                    struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
-                    struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_enc, attn_rel_b);
-                    struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
-                    cb(kq_b, "kq_b", il);
+                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+                cb(kq, "kq", il);
 
-                    kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias);
-                    cb(kq, "kq_soft_max_ext", il);
+                struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
+                struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_enc, attn_rel_b);
+                struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
+                cb(kq_b, "kq_b", il);
 
-                    struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
-                    cb(v, "v", il);
+                kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias);
+                cb(kq, "kq_soft_max_ext", il);
 
-                    struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
-                    cb(kqv, "kqv", il);
+                struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
+                cb(v, "v", il);
 
-                    struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                    cb(kqv_merged, "kqv_merged", il);
+                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
+                cb(kqv, "kqv", il);
 
-                    cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
-                    cb(cur, "kqv_merged_cont", il);
+                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                cb(kqv_merged, "kqv_merged", il);
 
-                    ggml_build_forward_expand(gf, cur);
+                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+                cb(cur, "kqv_merged_cont", il);
 
-                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_enc, cur);
-                    cb(cur, "kqv_out", il);
-                }
+                ggml_build_forward_expand(gf, cur);
 
-                if (il == n_layer - 1) {
-                    // skip computing output for unused tokens
-                    struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                    n_tokens = n_outputs;
-                    cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                    inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-                }
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_enc, cur);
+                cb(cur, "kqv_out", il);
+            }
 
-                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-                cb(ffn_inp, "ffn_inp", il);
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
 
-                // feed-forward network
-                {
-                    cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                            model.layers[il].ffn_norm_enc, NULL,
-                            LLM_NORM_RMS, cb, il);
-                    cb(cur, "ffn_norm", il);
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
 
-                    // T5 uses relu, flan-T5 uses gelu-gated
-                    cur = llm_build_ffn(ctx0, lctx, cur,
-                            model.layers[il].ffn_up_enc,   NULL, NULL,
-                            model.layers[il].ffn_gate_enc, NULL, NULL,
-                            model.layers[il].ffn_down_enc, NULL, NULL,
-                            NULL,
-                            model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
-                            model.layers[il].ffn_gate_enc ? LLM_FFN_PAR  : LLM_FFN_SEQ,
-                            cb, il);
-                    cb(cur, "ffn_out", il);
-                }
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm_enc, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
 
-                cur = ggml_add(ctx0, cur, ffn_inp);
+                // T5 uses relu, flan-T5 uses gelu-gated
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up_enc,   NULL, NULL,
+                        model.layers[il].ffn_gate_enc, NULL, NULL,
+                        model.layers[il].ffn_down_enc, NULL, NULL,
+                        NULL,
+                        model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
+                        model.layers[il].ffn_gate_enc ? LLM_FFN_PAR  : LLM_FFN_SEQ,
+                        cb, il);
                 cb(cur, "ffn_out", il);
+            }
 
-                ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-                if (layer_dir != nullptr) {
-                    cur = ggml_add(ctx0, cur, layer_dir);
-                }
-                cb(cur, "l_out", il);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
 
-                // input for next layer
-                inpL = cur;
+            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+            if (layer_dir != nullptr) {
+                cur = ggml_add(ctx0, cur, layer_dir);
             }
+            cb(cur, "l_out", il);
 
-            cur = inpL;
-            cb(cur, "result_embd", -1);
+            // input for next layer
+            inpL = cur;
+        }
 
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    model.output_norm_enc, NULL,
-                    LLM_NORM_RMS, cb, -1);
-            cb(cur, "result_norm", -1);
-        } else {
-            GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
+        cur = inpL;
+        cb(cur, "result_embd", -1);
 
-            struct ggml_tensor * embd_enc       = llm_build_inp_embd_enc();
-            struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm_enc, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
 
-            struct ggml_tensor * KQ_mask_dec   = build_inp_KQ_mask();
-            struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross();
+        ggml_build_forward_expand(gf, cur);
 
-            for (int il = 0; il < n_layer; ++il) {
-                struct ggml_tensor * inpSA = inpL;
+        return gf;
+    }
 
-                // norm
-                cur = llm_build_norm(ctx0, inpL, hparams,
-                        model.layers[il].attn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "attn_norm", il);
+    struct ggml_cgraph * build_t5_decoder() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
-                // self-attention
-                {
-                    struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                    cb(Qcur, "Qcur", il);
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
 
-                    struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                    cb(Kcur, "Kcur", il);
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
-                    struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                    cb(Vcur, "Vcur", il);
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
 
-                    llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
-                    struct ggml_tensor * k =
-                        ggml_view_3d(ctx0, kv_self.k_l[il],
-                                n_embd_head_k, n_kv, n_head_kv,
-                                ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                                ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
-                                0);
-                    cb(k, "k", il);
+        GGML_ASSERT(!lctx.is_encoding);
+        GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
 
-                    struct ggml_tensor * v =
-                        ggml_view_3d(ctx0, kv_self.v_l[il],
-                                n_kv, n_embd_head_v, n_head_kv,
-                                ggml_element_size(kv_self.v_l[il])*n_ctx,
-                                ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
-                                0);
-                    cb(v, "v", il);
+        struct ggml_tensor * embd_enc       = llm_build_inp_embd_enc();
+        struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
 
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+        struct ggml_tensor * KQ_mask_dec   = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross();
 
-                    struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
 
-                    struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-                    cb(kq, "kq", il);
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
 
-                    struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
-                    struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b);
-                    struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
-                    cb(kq_b, "kq_b", il);
+            // self-attention
+            {
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
 
-                    kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias);
-                    cb(kq, "kq_soft_max_ext", il);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
 
-                    struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-                    cb(kqv, "kqv", il);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
 
-                    struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                    cb(kqv_merged, "kqv_merged", il);
+                llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
 
-                    cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
-                    cb(cur, "kqv_merged_cont", il);
+                struct ggml_tensor * k =
+                    ggml_view_3d(ctx0, kv_self.k_l[il],
+                            n_embd_head_k, n_kv, n_head_kv,
+                            ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                            ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+                            0);
+                cb(k, "k", il);
 
-                    ggml_build_forward_expand(gf, cur);
+                struct ggml_tensor * v =
+                    ggml_view_3d(ctx0, kv_self.v_l[il],
+                            n_kv, n_embd_head_v, n_head_kv,
+                            ggml_element_size(kv_self.v_l[il])*n_ctx,
+                            ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+                            0);
+                cb(v, "v", il);
 
-                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
-                    cb(cur, "kqv_out", il);
-                }
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                cur = ggml_add(ctx0, cur, inpSA);
-                cb(cur, "cross_inp", il);
+                struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
 
-                struct ggml_tensor * inpCA = cur;
+                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+                cb(kq, "kq", il);
 
-                // norm
-                cur = llm_build_norm(ctx0, cur, hparams,
-                        model.layers[il].attn_norm_cross, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "attn_norm_cross", il);
+                struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
+                struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b);
+                struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
+                cb(kq_b, "kq_b", il);
 
-                // cross-attention
-                {
-                    struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_cross, cur);
-                    cb(Qcur, "Qcur", il);
+                kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias);
+                cb(kq, "kq_soft_max_ext", il);
 
-                    struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_cross, embd_enc);
-                    cb(Kcur, "Kcur", il);
+                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+                cb(kqv, "kqv", il);
 
-                    struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_cross, embd_enc);
-                    cb(Vcur, "Vcur", il);
+                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                cb(kqv_merged, "kqv_merged", il);
 
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
+                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+                cb(cur, "kqv_merged_cont", il);
 
-                    struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-                    struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+                ggml_build_forward_expand(gf, cur);
 
-                    struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-                    cb(kq, "kq", il);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+                cb(cur, "kqv_out", il);
+            }
 
-                    kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias);
-                    cb(kq, "kq_soft_max_ext", il);
+            cur = ggml_add(ctx0, cur, inpSA);
+            cb(cur, "cross_inp", il);
 
-                    struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc)));
-                    cb(v, "v", il);
+            struct ggml_tensor * inpCA = cur;
 
-                    struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq);
-                    cb(kqv, "kqv", il);
+            // norm
+            cur = llm_build_norm(ctx0, cur, hparams,
+                    model.layers[il].attn_norm_cross, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm_cross", il);
 
-                    struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                    cb(kqv_merged, "kqv_merged", il);
+            // cross-attention
+            {
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_cross, cur);
+                cb(Qcur, "Qcur", il);
 
-                    cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
-                    cb(cur, "kqv_merged_cont", il);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_cross, embd_enc);
+                cb(Kcur, "Kcur", il);
 
-                    ggml_build_forward_expand(gf, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_cross, embd_enc);
+                cb(Vcur, "Vcur", il);
 
-                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_cross, cur);
-                    cb(cur, "kqv_out", il);
-                }
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
 
-                if (il == n_layer - 1) {
-                    // skip computing output for unused tokens
-                    struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                    n_tokens = n_outputs;
-                    cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                    inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-                    inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
-                }
+                struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+                struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
 
-                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA);
-                cb(ffn_inp, "ffn_inp", il);
+                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+                cb(kq, "kq", il);
 
-                // feed-forward network
-                {
-                    cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                            model.layers[il].ffn_norm, NULL,
-                            LLM_NORM_RMS, cb, il);
-                    cb(cur, "ffn_norm", il);
+                kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias);
+                cb(kq, "kq_soft_max_ext", il);
 
-                    // T5 uses relu, flan-T5 uses gelu-gated
-                    cur = llm_build_ffn(ctx0, lctx, cur,
-                            model.layers[il].ffn_up,   NULL, NULL,
-                            model.layers[il].ffn_gate, NULL, NULL,
-                            model.layers[il].ffn_down, NULL, NULL,
-                            NULL,
-                            model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
-                            model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
-                            cb, il);
-                    cb(cur, "ffn_out", il);
-                }
+                struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc)));
+                cb(v, "v", il);
 
-                cur = ggml_add(ctx0, cur, ffn_inp);
-                cb(cur, "ffn_out", il);
+                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq);
+                cb(kqv, "kqv", il);
 
-                ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-                if (layer_dir != nullptr) {
-                    cur = ggml_add(ctx0, cur, layer_dir);
-                }
-                cb(cur, "l_out", il);
+                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                cb(kqv_merged, "kqv_merged", il);
 
-                // input for next layer
-                inpL = cur;
+                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+                cb(cur, "kqv_merged_cont", il);
+
+                ggml_build_forward_expand(gf, cur);
+
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_cross, cur);
+                cb(cur, "kqv_out", il);
             }
 
-            cur = inpL;
-            cb(cur, "result_embd", -1);
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
+            }
 
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    model.output_norm, NULL,
-                    LLM_NORM_RMS, cb, -1);
-            cb(cur, "result_norm", -1);
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA);
+            cb(ffn_inp, "ffn_inp", il);
 
-            // lm_head
-            cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-            cb(cur, "result_output", -1);
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                // T5 uses relu, flan-T5 uses gelu-gated
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
+                        model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
+                        cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+            if (layer_dir != nullptr) {
+                cur = ggml_add(ctx0, cur, layer_dir);
+            }
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
         }
 
+        cur = inpL;
+        cb(cur, "result_embd", -1);
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
         ggml_build_forward_expand(gf, cur);
 
         return gf;
@@ -13898,7 +13980,15 @@ static struct ggml_cgraph * llama_build_graph(
             } break;
         case LLM_ARCH_T5:
             {
-                result = llm.build_t5();
+                if (lctx.is_encoding) {
+                    result = llm.build_t5_encoder();
+                } else {
+                    result = llm.build_t5_decoder();
+                }
+            } break;
+        case LLM_ARCH_T5ENCODER:
+            {
+                result = llm.build_t5_encoder();
             } break;
         case LLM_ARCH_JAIS:
             {
@@ -14346,7 +14436,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
 
     // TODO: use a per-batch flag for logits presence instead
     const bool has_logits = !cparams.embeddings;
-    const bool has_embd   =  lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
+    const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
     const size_t embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
@@ -14829,9 +14919,24 @@ static int llama_encode_internal(
     ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
 
     // the output embeddings after the final encoder normalization
-    struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 1];
+    struct ggml_tensor * embd = nullptr;
 
-    GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
+    // there are two cases here
+    if (llama_model_has_decoder(&lctx.model)) {
+        // first case is an encoder-decoder T5 model where embeddings are passed to the decoder
+        embd = gf->nodes[gf->n_nodes - 1];
+        GGML_ASSERT(strcmp(embd->name, "result_norm") == 0 && "missing result_norm tensor");
+    } else {
+        // second case is an encoder-only T5 model
+        if (cparams.embeddings) {
+            // only output embeddings if required
+            embd = gf->nodes[gf->n_nodes - 1];
+            if (strcmp(embd->name, "result_embd_pooled") != 0) {
+                embd = gf->nodes[gf->n_nodes - 2];
+            }
+            GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
+        }
+    }
 
     ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
@@ -14844,20 +14949,54 @@ static int llama_encode_internal(
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
         GGML_ASSERT(backend_embd != nullptr);
 
-        // extract token embeddings
-        GGML_ASSERT(lctx.embd != nullptr);
+        if (llama_model_has_decoder(&lctx.model)) {
+            lctx.embd_enc.resize(n_tokens*n_embd);
+            float * embd_out = lctx.embd_enc.data();
 
-        lctx.embd_enc.resize(n_tokens*n_embd);
-        float * embd_out = lctx.embd_enc.data();
+            ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
 
-        ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+            // remember the sequence ids used during the encoding - needed for cross attention later
+            lctx.seq_ids_enc.resize(n_tokens);
+            for (uint32_t i = 0; i < n_tokens; i++) {
+                for (int s = 0; s < batch.n_seq_id[i]; s++) {
+                    llama_seq_id seq_id = batch.seq_id[i][s];
+                    lctx.seq_ids_enc[i].insert(seq_id);
+                }
+            }
+        } else {
+            GGML_ASSERT(lctx.embd != nullptr);
 
-        // remember the sequence ids used during the encoding - needed for cross attention later
-        lctx.seq_ids_enc.resize(n_tokens);
-        for (uint32_t i = 0; i < n_tokens; i++) {
-            for (int s = 0; s < batch.n_seq_id[i]; s++) {
-                llama_seq_id seq_id = batch.seq_id[i][s];
-                lctx.seq_ids_enc[i].insert(seq_id);
+            switch (cparams.pooling_type) {
+                case LLAMA_POOLING_TYPE_NONE:
+                    {
+                        // extract token embeddings
+                        GGML_ASSERT(lctx.embd != nullptr);
+                        float * embd_out = lctx.embd;
+
+                        GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size);
+                        ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+                    } break;
+                case LLAMA_POOLING_TYPE_MEAN:
+                case LLAMA_POOLING_TYPE_CLS:
+                case LLAMA_POOLING_TYPE_LAST:
+                    {
+                        // extract sequence embeddings
+                        auto & embd_seq_out = lctx.embd_seq;
+                        embd_seq_out.clear();
+
+                        for (uint32_t i = 0; i < n_tokens; i++) {
+                            const llama_seq_id seq_id = batch.seq_id[i][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(n_embd);
+                            ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                    {
+                        GGML_ABORT("unknown pooling type");
+                    }
             }
         }
     }
@@ -16567,6 +16706,8 @@ struct llama_context * llama_new_context_with_model(
 
     ctx->sampling.rng = std::mt19937(params.seed);
     ctx->logits_all   = params.logits_all;
+    // build worst-case graph for encoder if the model contains an encoder
+    ctx->is_encoding  = llama_model_has_encoder(model);
 
     uint32_t kv_size = cparams.n_ctx;
     ggml_type type_k = params.type_k;
@@ -16881,6 +17022,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_MAMBA:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_T5:
+        case LLM_ARCH_T5ENCODER:
         case LLM_ARCH_JAIS:
             return LLAMA_ROPE_TYPE_NONE;
 
@@ -17028,8 +17170,16 @@ struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const ch
 
 bool llama_model_has_encoder(const struct llama_model * model) {
     switch (model->arch) {
-        case LLM_ARCH_T5: return true;
-        default:          return false;
+        case LLM_ARCH_T5:        return true;
+        case LLM_ARCH_T5ENCODER: return true;
+        default:                 return false;
+    }
+}
+
+bool llama_model_has_decoder(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_T5ENCODER: return false;
+        default:                 return true;
     }
 }