]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : sync llama.cpp
authorGeorgi Gerganov <redacted>
Wed, 28 Aug 2024 08:04:02 +0000 (11:04 +0300)
committerGeorgi Gerganov <redacted>
Wed, 28 Aug 2024 10:22:20 +0000 (13:22 +0300)
examples/talk-llama/llama-impl.h
examples/talk-llama/llama-sampling.cpp
examples/talk-llama/llama-vocab.cpp
examples/talk-llama/llama-vocab.h
examples/talk-llama/llama.cpp
examples/talk-llama/llama.h

index dcc8c1c15a1b1ea776e5f9a29b9bfc7fd5d310c5..9527740961da652657d15b898d973ad8aac6d33c 100644 (file)
@@ -24,3 +24,24 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
 #define LLAMA_LOG_INFO(...)  llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
 #define LLAMA_LOG_WARN(...)  llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
 #define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
+//
+// helpers
+//
+
+static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
+    if (search.empty()) {
+        return;
+    }
+    std::string builder;
+    builder.reserve(s.length());
+    size_t pos = 0;
+    size_t last_pos = 0;
+    while ((pos = s.find(search, last_pos)) != std::string::npos) {
+        builder.append(s, last_pos, pos - last_pos);
+        builder.append(replace);
+        last_pos = pos + search.length();
+    }
+    builder.append(s, last_pos, std::string::npos);
+    s = std::move(builder);
+}
index 8910f6d6542e91069d059d83105bbd63108a795a..8f4841d9daf7b90f681eca1d0c989e761d205181 100644 (file)
@@ -85,14 +85,14 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
             constexpr float bucket_low   = -10.0f;
             constexpr float bucket_high  =  10.0f;
             constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
-            constexpr float bucker_inter = -bucket_low * bucket_scale;
+            constexpr float bucket_inter = -bucket_low * bucket_scale;
 
             std::vector<int> bucket_idx(candidates->size);
             std::vector<int> histo(nbuckets, 0);
 
             for (int i = 0; i < (int)candidates->size; ++i) {
                 const float val = candidates->data[i].logit;
-                int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
+                int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
                 ib = std::max(0, std::min(nbuckets-1, ib));
                 bucket_idx[i] = ib;
                 ++histo[ib];
index e6d6059d034828513f5201e2af56d4e1881b8aaf..323660ef54cb07a7c90f15d86620583375754fc3 100644 (file)
 // helpers
 //
 
-static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
-    std::string result;
-    for (size_t pos = 0; ; pos += search.length()) {
-        auto new_pos = s.find(search, pos);
-        if (new_pos == std::string::npos) {
-            result += s.substr(pos, s.size() - pos);
-            break;
-        }
-        result += s.substr(pos, new_pos - pos) + replace;
-        pos = new_pos;
-    }
-    s = std::move(result);
-}
-
 LLAMA_ATTRIBUTE_FORMAT(1, 2)
 static std::string format(const char * fmt, ...) {
     va_list ap;
@@ -335,6 +321,21 @@ private:
 
 // TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused
 
+template<typename T, typename Container = std::vector<T>, typename Compare = std::less<typename Container::value_type>>
+class llama_priority_queue : public std::priority_queue<T, Container, Compare> {
+public:
+    using std::priority_queue<T, Container, Compare>::priority_queue;
+
+    T pop_move() {
+        T item = std::move(this->c.front());
+        std::pop_heap(this->c.begin(), this->c.end(), this->comp);
+        this->c.pop_back();
+        return item;
+    }
+
+    void pop() =  delete;
+};
+
 struct llm_bigram_bpe {
     struct comparator {
         bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const {
@@ -343,7 +344,7 @@ struct llm_bigram_bpe {
     };
 
     using queue_storage = std::vector<llm_bigram_bpe>;
-    using queue = std::priority_queue<llm_bigram_bpe, queue_storage, comparator>;
+    using queue = llama_priority_queue<llm_bigram_bpe, queue_storage, comparator>;
     llm_symbol::index left;
     llm_symbol::index right;
     std::string text;
@@ -402,6 +403,7 @@ struct llm_tokenizer_bpe {
             case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
             case LLAMA_VOCAB_PRE_TYPE_SMOLLM:
             case LLAMA_VOCAB_PRE_TYPE_CODESHELL:
+            case LLAMA_VOCAB_PRE_TYPE_EXAONE:
                 regex_exprs = {
                     "\\p{N}",
                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
@@ -424,6 +426,8 @@ struct llm_tokenizer_bpe {
                 };
                 break;
             case LLAMA_VOCAB_PRE_TYPE_PORO:
+            case LLAMA_VOCAB_PRE_TYPE_BLOOM:
+            case LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH:
                 regex_exprs = {
                     " ?[^(\\s|.,!?…。,、।۔،)]+",
                 };
@@ -531,8 +535,7 @@ struct llm_tokenizer_bpe {
 
             // build token(s)
             while (!work_queue.empty()) {
-                auto bigram = work_queue.top();
-                work_queue.pop();
+                auto bigram = work_queue.pop_move();
 
                 auto & left_symbol = symbols[bigram.left];
                 auto & right_symbol = symbols[bigram.right];
@@ -1480,11 +1483,11 @@ llama_token llama_token_pad_impl(const struct llama_vocab & vocab) {
     return vocab.special_pad_id;
 }
 
-int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab) {
+bool llama_add_bos_token_impl(const struct llama_vocab & vocab) {
     return vocab.tokenizer_add_bos;
 }
 
-int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab) {
+bool llama_add_eos_token_impl(const struct llama_vocab & vocab) {
     return vocab.tokenizer_add_eos;
 }
 
index 7adfc16da3af3839fb10e202d9d2460368948e56..6e8f30be43ba1cb2f93cd07d8189546998de49d1 100644 (file)
@@ -95,8 +95,8 @@ llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
 llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
 llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
 
-int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab);
-int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab);
+bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
+bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
 
 llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
 llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
index a7b1c9ebd9e37d1e9017bdec315b86ecd51a738b..8d5f24783d6aba1d5fad15718f74c4586aaf1410 100644 (file)
@@ -121,17 +121,6 @@ static std::string trim(const std::string & str) {
     return str.substr(start, end - start);
 }
 
-static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
-    if (search.empty()) {
-        return; // Avoid infinite loop if 'search' is an empty string
-    }
-    size_t pos = 0;
-    while ((pos = s.find(search, pos)) != std::string::npos) {
-        s.replace(pos, search.length(), replace);
-        pos += replace.length();
-    }
-}
-
 static bool is_float_close(float a, float b, float abs_tol) {
     // Check for non-negative tolerance
     if (abs_tol < 0.0) {
@@ -219,7 +208,10 @@ enum llm_arch {
     LLM_ARCH_CHATGLM,
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
+    LLM_ARCH_T5ENCODER,
     LLM_ARCH_JAIS,
+    LLM_ARCH_NEMOTRON,
+    LLM_ARCH_EXAONE,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -263,7 +255,10 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_CHATGLM,         "chatglm"      },
     { LLM_ARCH_BITNET,          "bitnet"       },
     { LLM_ARCH_T5,              "t5"           },
+    { LLM_ARCH_T5ENCODER,       "t5encoder"    },
     { LLM_ARCH_JAIS,            "jais"         },
+    { LLM_ARCH_NEMOTRON,        "nemotron"     },
+    { LLM_ARCH_EXAONE,          "exaone"       },
     { LLM_ARCH_UNKNOWN,         "(unknown)"    },
 };
 
@@ -333,6 +328,7 @@ enum llm_kv {
     LLM_KV_SSM_CONV_KERNEL,
     LLM_KV_SSM_STATE_SIZE,
     LLM_KV_SSM_TIME_STEP_RANK,
+    LLM_KV_SSM_DT_B_C_RMS,
 
     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_PRE,
@@ -431,6 +427,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_SSM_INNER_SIZE,                "%s.ssm.inner_size"     },
     { LLM_KV_SSM_STATE_SIZE,                "%s.ssm.state_size"     },
     { LLM_KV_SSM_TIME_STEP_RANK,            "%s.ssm.time_step_rank" },
+    { LLM_KV_SSM_DT_B_C_RMS,                "%s.ssm.dt_b_c_rms" },
 
     { LLM_KV_TOKENIZER_MODEL,                "tokenizer.ggml.model"                    },
     { LLM_KV_TOKENIZER_PRE,                  "tokenizer.ggml.pre"                      },
@@ -1272,6 +1269,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_ENC_FFN_UP,           "enc.blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_T5ENCODER,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,           "token_embd" },
+            { LLM_TENSOR_OUTPUT,               "output" },
+            { LLM_TENSOR_ENC_OUTPUT_NORM,      "enc.output_norm" },
+            { LLM_TENSOR_ENC_ATTN_NORM,        "enc.blk.%d.attn_norm" },
+            { LLM_TENSOR_ENC_ATTN_Q,           "enc.blk.%d.attn_q" },
+            { LLM_TENSOR_ENC_ATTN_K,           "enc.blk.%d.attn_k" },
+            { LLM_TENSOR_ENC_ATTN_V,           "enc.blk.%d.attn_v" },
+            { LLM_TENSOR_ENC_ATTN_OUT,         "enc.blk.%d.attn_o" },
+            { LLM_TENSOR_ENC_ATTN_REL_B,       "enc.blk.%d.attn_rel_b" },
+            { LLM_TENSOR_ENC_FFN_NORM,         "enc.blk.%d.ffn_norm" },
+            { LLM_TENSOR_ENC_FFN_GATE,         "enc.blk.%d.ffn_gate" },
+            { LLM_TENSOR_ENC_FFN_DOWN,         "enc.blk.%d.ffn_down" },
+            { LLM_TENSOR_ENC_FFN_UP,           "enc.blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_JAIS,
         {
@@ -1287,6 +1302,43 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
         },
     },
+    {
+        LLM_ARCH_NEMOTRON,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
+    {
+        LLM_ARCH_EXAONE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2187,6 +2239,7 @@ struct llama_hparams {
     uint32_t ssm_d_inner = 0;
     uint32_t ssm_d_state = 0;
     uint32_t ssm_dt_rank = 0;
+    bool ssm_dt_b_c_rms = false;
 
     float f_clamp_kqv      = 0.0f;
     float f_max_alibi_bias = 0.0f;
@@ -2236,6 +2289,7 @@ struct llama_hparams {
         if (this->ssm_d_inner != other.ssm_d_inner) return true;
         if (this->ssm_d_state != other.ssm_d_state) return true;
         if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
+        if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
 
         if (this->dec_start_token_id != other.dec_start_token_id) return true;
 
@@ -2462,10 +2516,29 @@ struct llama_layer {
     struct ggml_tensor * ffn_down_scale;
 };
 
+// very similar to llama_batch,
+// but has more metadata about sequences
+struct llama_ubatch {
+    bool equal_seqs;
+    // TODO: whole_seqs for embeddings?
+
+    uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
+    uint32_t n_seq_tokens; // tokens per sequence
+    uint32_t n_seqs;
+
+    llama_token  *  token;    // [n_tokens]
+    float        *  embd;     // [n_embd, n_tokens]
+    llama_pos    *  pos;      // [n_tokens]
+    int32_t      *  n_seq_id; // [n_seqs]
+    llama_seq_id ** seq_id;   // [n_seqs]
+    int8_t       *  output;   // [n_tokens]
+};
+
 struct llama_kv_cell {
     llama_pos pos   = -1;
     llama_pos delta = 0;
-    int32_t   src   = 0; // used by recurrent state models to copy states
+    int32_t   src   = -1; // used by recurrent state models to copy states
+    int32_t   tail  = -1;
 
     std::set<llama_seq_id> seq_id;
 
@@ -2486,7 +2559,6 @@ struct llama_kv_cell {
 struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
-    bool do_copy   = false;
     bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
     bool v_trans   = true;  // the value tensor is transposed
 
@@ -2649,6 +2721,340 @@ struct llama_model {
     }
 };
 
+struct llama_sbatch_seq {
+    int32_t n_seq_id;
+    llama_seq_id * seq_id;
+    size_t offset;
+    size_t length;
+
+    // helper for smoother batch API transition -- can be deprecated in the future
+    llama_seq_id all_seq_id; // used if seq_id == NULL
+};
+
+// sequence-length-aware batch splitting
+struct llama_sbatch {
+    // tokens left in this batch
+    size_t n_tokens;
+
+    size_t n_embd;
+
+    bool logits_all; // TODO: remove once lctx.logits_all is removed too
+
+    // sorted indices into the batch
+    std::vector<size_t> ids;
+    // batch indices of the output
+    std::vector<size_t> out_ids;
+    std::vector<llama_sbatch_seq> seq;
+    const llama_batch * batch = nullptr;
+
+    // buffers for the ubatch
+    std::vector<llama_token>    ubatch_token;
+    std::vector<float>          ubatch_embd;
+    std::vector<llama_pos>      ubatch_pos;
+    std::vector<int32_t>        ubatch_n_seq_id;
+    std::vector<llama_seq_id *> ubatch_seq_id;
+    std::vector<int8_t>         ubatch_output;
+
+    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false) {
+        // clear empty sequences
+        // the previous ubatch is assumed to be gone,
+        // so nothing should refer to values in these sequences anymore.
+        for (size_t i = seq.size(); i-- > 0;) {
+            if (seq[i].length == 0) {
+                seq.pop_back();
+            } else {
+                break;
+            }
+        }
+        ubatch_token.resize(!has_embd ? n_ubatch : 0);
+        ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
+        ubatch_pos.resize(n_ubatch);
+        ubatch_n_seq_id.resize(n_ubatch);
+        ubatch_seq_id.resize(n_ubatch);
+        ubatch_output.resize(n_ubatch);
+        llama_ubatch ubatch = {
+            /*equal_seqs   =*/ true,
+            /*n_tokens     =*/ 0,
+            /*n_seq_tokens =*/ 0,
+            /*n_seqs       =*/ 0,
+            /*token        =*/ !has_embd ? ubatch_token.data() : nullptr,
+            /*embd         =*/ has_embd  ? ubatch_embd.data()  : nullptr,
+            /*pos          =*/ ubatch_pos.data(),
+            /*n_seq_id     =*/ ubatch_n_seq_id.data(),
+            /*seq_id       =*/ ubatch_seq_id.data(),
+            /*output       =*/ ubatch_output.data(),
+        };
+        return ubatch;
+    }
+
+    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
+        GGML_ASSERT(batch != nullptr);
+        GGML_ASSERT(length <= seq.length);
+        // Can only add sequences of equal lengths to a batch,
+        // otherwise it isn't clear to which sequence a token belongs
+        GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
+        GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
+        // NOTE: loops are separated for cache-friendliness
+        if (batch->token) {
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
+                }
+            } else {
+                // simple split
+                ubatch.token = batch->token + seq.offset;
+            }
+        } else {
+            ubatch.token = nullptr;
+        }
+        if (batch->embd) {
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    memcpy(
+                        ubatch.embd + n_embd * (ubatch.n_tokens + i),
+                        batch->embd + n_embd * ids[seq.offset + i],
+                        n_embd * sizeof(float)
+                    );
+                }
+            } else {
+                // simple split
+                ubatch.embd = batch->embd + (n_embd * seq.offset);
+            }
+        } else {
+            ubatch.embd = nullptr;
+        }
+        // from here on, the else branches are deprecated;
+        // they are helpers for smoother batch API transition
+        if (batch->pos) {
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
+                }
+            } else {
+                // simple split
+                ubatch.pos = batch->pos + seq.offset;
+            }
+        } else {
+            for (size_t i = 0; i < length; ++i) {
+                llama_pos bi = ids[seq.offset + i];
+                ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
+            }
+        }
+        if (ubatch.equal_seqs) {
+            ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
+            if (seq.seq_id) {
+                ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
+            } else {
+                GGML_ASSERT(seq.n_seq_id == 1);
+                ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
+            }
+        } else {
+            // simple split
+            if (batch->n_seq_id) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.n_seq_id = batch->n_seq_id + seq.offset;
+                }
+            } else {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
+                }
+            }
+            if (batch->seq_id) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.seq_id = batch->seq_id + seq.offset;
+                }
+            } else {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
+                }
+            }
+        }
+        if (logits_all) {
+            for (size_t i = 0; i < length; ++i) {
+                ubatch.output[ubatch.n_tokens + i] = 1;
+                out_ids.push_back(ids[seq.offset + i]);
+            }
+        } else if (batch->logits) {
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    size_t id = ids[seq.offset + i];
+                    int8_t is_output = batch->logits[id];
+                    ubatch.output[ubatch.n_tokens + i] = is_output;
+                    if (is_output) { out_ids.push_back(id); }
+                }
+            } else {
+                // simple split
+                ubatch.output = batch->logits + seq.offset;
+                for (size_t i = 0; i < length; ++i) {
+                    if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
+                }
+            }
+        } else {
+            // only get last output
+            for (size_t i = 0; i < length; ++i) {
+                size_t id = ids[seq.offset + i];
+                int8_t is_last = id == ids.size() - 1;
+                ubatch.output[ubatch.n_tokens + i] = is_last;
+                if (is_last) { out_ids.push_back(id); }
+            }
+        }
+        if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
+            ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
+        }
+        ubatch.n_tokens += length;
+        ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
+        seq.offset += length;
+        seq.length -= length;
+        n_tokens -= length;
+        GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
+    }
+
+    // simple split, unknown number of sequences of unequal lengths
+    llama_ubatch split_simple(size_t n_ubatch) {
+        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+        ubatch.equal_seqs = false;
+        if (!seq.empty()) {
+            llama_sbatch_seq & s = seq[0];
+            size_t length = s.length < n_ubatch ? s.length : n_ubatch;
+            GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
+            add_seq_to_ubatch(ubatch, s, length);
+        }
+        return ubatch;
+    }
+
+    // make batches of equal-length sequences
+    llama_ubatch split_equal(size_t n_ubatch) {
+        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+        if (!seq.empty()) {
+            size_t length = 0;
+            size_t n_tokens_in_ubatch = 0;
+            GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
+            // smallest first, because it's easier to split this way;
+            // starting from the end to pop in constant time.
+            for (size_t i = seq.size(); i-- > 0;) {
+                llama_sbatch_seq & s = seq[i];
+                GGML_ASSERT(s.length > 0);
+                if (length == 0) {
+                    length = s.length < n_ubatch ? s.length : n_ubatch;
+                }
+                add_seq_to_ubatch(ubatch, s, length);
+                n_tokens_in_ubatch += length;
+                // shared prompts can't be mixed with any of their sequences,
+                // so it's safer to compute them in their own ubatch
+                if (s.n_seq_id > 1) { break; }
+                // stop when there isn't enough space for another sequence
+                if (length + n_tokens_in_ubatch > n_ubatch) { break; }
+            }
+        }
+        return ubatch;
+    }
+
+    // sequence-wise split
+    llama_ubatch split_seq(size_t n_ubatch) {
+        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
+        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+        if (!seq.empty()) {
+            llama_sbatch_seq & s = seq[seq.size() - 1];
+            size_t length = s.length < n_ubatch ? s.length : n_ubatch;
+            GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
+            add_seq_to_ubatch(ubatch, s, length);
+        }
+        return ubatch;
+    }
+
+    void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) {
+        GGML_ASSERT(batch.n_tokens >= 0);
+        this->batch = &batch;
+        this->n_embd = n_embd;
+        this->logits_all = logits_all;
+
+        n_tokens = batch.n_tokens;
+        ids.resize(n_tokens);
+        out_ids.clear();
+        // TODO: reserve out_ids and seq
+
+        for (size_t i = 0; i < n_tokens; ++i) {
+            ids[i] = i;
+        }
+        if (simple_split) {
+            seq.resize(1);
+            llama_sbatch_seq & s = seq[0];
+            s.n_seq_id = 0;
+            s.seq_id = nullptr;
+            s.offset = 0;
+            s.length = n_tokens;
+            s.all_seq_id = batch.all_seq_id;
+            return;
+        }
+        std::sort(ids.begin(), ids.end(),
+            [&batch](size_t a, size_t b) {
+                int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
+                int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
+                // sort by seq_id, then by pos
+                if (n_seq_a == n_seq_b) {
+                    if (batch.seq_id) {
+                        for (int32_t i = 0; i < n_seq_a; ++i) {
+                            llama_seq_id seq_id_a = batch.seq_id[a][i];
+                            llama_seq_id seq_id_b = batch.seq_id[b][i];
+                            // smaller seq_ids go first
+                            if (seq_id_a != seq_id_b) {
+                                return seq_id_a < seq_id_b;
+                            }
+                        }
+                    }
+                    // when all else is equal, sort by pos
+                    if (batch.pos) {
+                        return batch.pos[a] < batch.pos[b];
+                    }
+                    // no pos, sort by id (assuming batch.all_pos_1 is positive)
+                    return a < b;
+                }
+                // shared prompts go first
+                return n_seq_a > n_seq_b;
+            }
+        );
+        // init seq
+        llama_sbatch_seq * last_seq = nullptr;
+
+        if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
+            for (size_t i = 0; i < n_tokens; ++i) {
+                const size_t bi = ids[i];
+                const int32_t n_seqs = batch.n_seq_id[bi];
+                llama_seq_id * seq_ids = batch.seq_id[bi];
+                if (last_seq != nullptr) {
+                    bool same = n_seqs == last_seq->n_seq_id;
+                    for (int32_t j = 0; same && j < n_seqs; ++j) {
+                        if (seq_ids[j] != last_seq->seq_id[j]) {
+                            same = false;
+                        }
+                    }
+                    if (same) {
+                        last_seq->length += 1;
+                        continue;
+                    }
+                }
+                llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
+                seq.push_back(new_seq);
+                last_seq = &seq.back();
+            }
+        } else {
+            llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
+            seq.push_back(new_seq);
+        }
+        // keep shared prompts first at the end, then sort by length descending.
+        std::sort(seq.begin(), seq.end(),
+            [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
+                if (a.n_seq_id == b.n_seq_id) {
+                    return a.length > b.length;
+                }
+                return a.n_seq_id < b.n_seq_id;
+            }
+        );
+    }
+};
+
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
@@ -2670,6 +3076,7 @@ struct llama_context {
 
     struct llama_cparams        cparams;
     struct llama_sampling       sampling;
+    struct llama_sbatch         sbatch;
     struct llama_kv_cache       kv_self;
     struct llama_control_vector cvec;
 
@@ -2930,8 +3337,7 @@ static bool llama_kv_cache_init(
 
     cache.has_shift = false;
 
-    // TODO: find a nicer way to add other recurrent model architectures
-    cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+    cache.recurrent = llama_model_is_recurrent(&model);
     cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
 
     cache.head = 0;
@@ -2944,13 +3350,6 @@ static bool llama_kv_cache_init(
     cache.cells.clear();
     cache.cells.resize(kv_size);
 
-    if (cache.recurrent) {
-        // init state copy sources
-        for (uint32_t i = 0; i < cache.size; ++i) {
-            cache.cells[i].src = i;
-        }
-    }
-
     // count used buffer types
     std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
     if (offload) {
@@ -3018,54 +3417,170 @@ static bool llama_kv_cache_init(
 // to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-        const struct llama_batch & batch) {
+       const struct llama_ubatch & batch) {
     const uint32_t n_tokens = batch.n_tokens;
+    const uint32_t n_seqs   = batch.n_seqs;
+    const uint32_t n_seq_tokens = batch.n_seq_tokens;
 
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba),
-        // each KV cache cell can store the state for a whole sequence.
-
-        llama_seq_id min = cache.size - 1;
-        llama_seq_id max = 0;
-
-        for (uint32_t i = 0; i < n_tokens; ++i) {
-            for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
-                llama_seq_id seq_id = batch.seq_id[i][j];
-                // make sure it's a valid seq_id
-                if ((uint32_t) seq_id < cache.size) {
-                    if (seq_id > max) {
-                        max = seq_id;
-                    }
-                    if (seq_id < min) {
-                        min = seq_id;
+        // each cache cell can store the state for a whole sequence.
+        // A slot should be always be contiguous.
+
+        // can only process batches with an equal number of new tokens in each sequence
+        GGML_ASSERT(batch.equal_seqs);
+
+        int32_t min = cache.size - 1;
+        int32_t max = 0;
+
+        // everything should fit if all seq_ids are smaller than the max
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            const uint32_t n_seq_id = batch.n_seq_id[s];
+            for (uint32_t j = 0; j < n_seq_id; ++j) {
+                const llama_seq_id seq_id = batch.seq_id[s][j];
+
+                if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
+                    // too big seq_id
+                    // TODO: would it be possible to resize the cache instead?
+                    LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
+                    return false;
+                }
+                if (j > 0) {
+                    llama_kv_cell & seq = cache.cells[seq_id];
+                    if (seq.tail >= 0) {
+                        llama_kv_cell & cell = cache.cells[seq.tail];
+                        // clear cells from seq_ids that become shared
+                        // (should not normally happen, but let's handle it anyway)
+                        cell.seq_id.erase(seq_id);
+                        seq.tail = -1;
+                        if (cell.seq_id.empty()) {
+                            cell.pos = -1;
+                            cell.src = -1;
+                            cache.used -= 1;
+                        }
                     }
-                    // Assuming the tokens are in-order
-                    if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
-                        // What should happen when the pos backtracks or skips a value?
-                        // Clearing the state mid-batch would require special-casing which isn't done.
-                        LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
-                            __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
+                }
+            }
+        }
+
+#ifndef NDEBUG
+        {
+            std::vector<int32_t> tails_verif;
+            tails_verif.assign(cache.size, -1);
+            for (uint32_t i = 0; i < cache.size; ++i) {
+                llama_kv_cell & cell = cache.cells[i];
+                for (llama_seq_id seq_id : cell.seq_id) {
+                    if (tails_verif[seq_id] != -1) {
+                        LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
                     }
-                    if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) {
-                        cache.used += 1;
+                    tails_verif[seq_id] = i;
+                }
+            }
+            for (uint32_t i = 0; i < cache.size; ++i) {
+                if (tails_verif[i] != cache.cells[i].tail) {
+                    LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]);
+                }
+            }
+        }
+#endif
+
+        // find next empty cell
+        uint32_t next_empty_cell = cache.head;
+
+        for (uint32_t i = 0; i < cache.size; ++i) {
+            if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
+            llama_kv_cell & cell = cache.cells[next_empty_cell];
+            if (cell.is_empty()) { break; }
+            next_empty_cell += 1;
+        }
+
+        // find usable cell range
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = batch.seq_id[s][0];
+            llama_kv_cell & seq_meta = cache.cells[seq_id];
+            bool has_cell = false;
+            if (seq_meta.tail >= 0) {
+                llama_kv_cell & cell = cache.cells[seq_meta.tail];
+                GGML_ASSERT(cell.has_seq_id(seq_id));
+                // does this seq_id "own" the cell?
+                if (cell.seq_id.size() == 1) { has_cell = true; }
+            }
+            if (!has_cell) {
+                llama_kv_cell & empty_cell = cache.cells[next_empty_cell];
+                GGML_ASSERT(empty_cell.is_empty());
+                // copy old tail into the empty cell
+                if (seq_meta.tail >= 0) {
+                    llama_kv_cell & orig_cell = cache.cells[seq_meta.tail];
+                    empty_cell.pos = orig_cell.pos;
+                    empty_cell.src = orig_cell.src;
+                    orig_cell.seq_id.erase(seq_id);
+                    empty_cell.seq_id.insert(seq_id); // will be overwritten
+                }
+                seq_meta.tail = next_empty_cell;
+                // find next empty cell
+                if (s + 1 < n_seqs) {
+                    next_empty_cell += 1;
+                    for (uint32_t i = 0; i < cache.size; ++i) {
+                        if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
+                        llama_kv_cell & cell = cache.cells[next_empty_cell];
+                        if (cell.is_empty()) { break; }
+                        next_empty_cell += 1;
                     }
-                    cache.cells[seq_id].pos = batch.pos[i];
-                    // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
-                } else {
-                    // too big seq_id
-                    // TODO: would it be possible to resize the KV cache size instead?
-                    LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
-                    return false;
+                }
+            }
+            if (min > seq_meta.tail) { min = seq_meta.tail; }
+            if (max < seq_meta.tail) { max = seq_meta.tail; }
+        }
+
+        // gather and re-order
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            int32_t dst_id = s + min;
+            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            if (dst_id != src_id) {
+                llama_kv_cell & dst_cell = cache.cells[dst_id];
+                llama_kv_cell & src_cell = cache.cells[src_id];
+
+                std::swap(dst_cell.pos, src_cell.pos);
+                std::swap(dst_cell.src, src_cell.src);
+                std::swap(dst_cell.seq_id, src_cell.seq_id);
+
+                // swap tails (assuming they NEVER overlap)
+                for (const llama_seq_id seq_id : src_cell.seq_id) {
+                    cache.cells[seq_id].tail = src_id;
+                }
+                for (const llama_seq_id seq_id : dst_cell.seq_id) {
+                    cache.cells[seq_id].tail = dst_id;
                 }
             }
         }
 
+        // update the pos of the used seqs
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            int32_t cell_id = s + min;
+            llama_kv_cell & cell = cache.cells[cell_id];
+
+            if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
+                // What should happen when the pos backtracks or skips a value?
+                // Clearing the state mid-batch would require special-casing which isn't done.
+                LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
+                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+            }
+            cell.pos = last_pos;
+            cell.seq_id.clear();
+            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = batch.seq_id[s][j];
+                cell.seq_id.insert(seq_id);
+                cache.cells[seq_id].tail = cell_id;
+            }
+        }
+
         // allow getting the range of used cells, from head to head + n
         cache.head = min;
         cache.n    = max - min + 1;
 
         // sanity check
-        return max >= min;
+        return cache.n >= n_seqs;
     }
     // otherwise, one cell per token.
 
@@ -3103,11 +3618,14 @@ static bool llama_kv_cache_find_slot(
         }
     }
 
-    for (uint32_t i = 0; i < n_tokens; i++) {
-        cache.cells[cache.head + i].pos = batch.pos[i];
+    for (uint32_t s = 0; s < n_seqs; s++) {
+        for (uint32_t i = 0; i < n_seq_tokens; ++i) {
+            uint32_t k = s*n_seq_tokens + i;
+            cache.cells[cache.head + k].pos = batch.pos[k];
 
-        for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
-            cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
+            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            }
         }
     }
 
@@ -3133,6 +3651,8 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
     for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
+        cache.cells[i].src = -1;
+        cache.cells[i].tail = -1;
     }
     cache.head = 0;
     cache.used = 0;
@@ -3159,9 +3679,16 @@ static bool llama_kv_cache_seq_rm(
             return false;
         }
         if (0 <= seq_id) {
-            // partial intersection is invalid
-            if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) {
-                return false;
+            int32_t & tail_id = cache.cells[seq_id].tail;
+            if (tail_id >= 0) {
+                const llama_kv_cell & cell = cache.cells[tail_id];
+                // partial intersection is invalid
+                if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+                    return false;
+                }
+                if (p0 <= cell.pos && p1 < cell.pos) {
+                    tail_id = -1;
+                }
             }
         } else {
             // seq_id is negative, then the range should include everything or nothing
@@ -3185,6 +3712,7 @@ static bool llama_kv_cache_seq_rm(
                 if (cache.cells[i].pos >= 0) cache.used--;
 
                 cache.cells[i].pos = -1;
+                cache.cells[i].src = -1;
                 if (new_head == cache.size) new_head = i;
             }
         }
@@ -3207,23 +3735,29 @@ static void llama_kv_cache_seq_cp(
 
     if (cache.recurrent) {
         if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
-            seq_id_src = cache.cells[seq_id_src].src;
-            GGML_ASSERT((uint32_t) seq_id_src < cache.size);
-            // intent to "copy from"
-            // supports copy chains thanks to taking the source of the source
-            cache.cells[seq_id_dst].src = seq_id_src;
-
-            // preserve the "keep or clear" status of the copied sequence
-            if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
-                cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
-            } else {
-                cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
+            llama_kv_cell & tail_src = cache.cells[seq_id_src];
+            llama_kv_cell & tail_dst = cache.cells[seq_id_dst];
+            if (tail_dst.tail >= 0) {
+                // clear destination seq_id if it wasn't empty
+                llama_kv_cell & cell_dst = cache.cells[tail_dst.tail];
+
+                cell_dst.seq_id.erase(seq_id_dst);
+                tail_dst.tail = -1;
+                if (cell_dst.seq_id.empty()) {
+                    cell_dst.pos = -1;
+                    cell_dst.delta = -1;
+                    cell_dst.src = -1;
+                    cache.used -= 1;
+                }
             }
+            if (tail_src.tail >= 0) {
+                llama_kv_cell & cell_src = cache.cells[tail_src.tail];
 
-            cache.do_copy = true;
-
-            cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
+                cell_src.seq_id.insert(seq_id_dst);
+                tail_dst.tail = tail_src.tail;
+            }
         }
+
         return;
     }
     // otherwise, this is the KV cache of a Transformer-like model
@@ -3241,9 +3775,13 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
     uint32_t new_head = cache.size;
 
     for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.recurrent && (llama_seq_id) i != seq_id) {
+            cache.cells[i].tail = -1;
+        }
         if (!cache.cells[i].has_seq_id(seq_id)) {
             if (cache.cells[i].pos >= 0) cache.used--;
             cache.cells[i].pos = -1;
+            cache.cells[i].src = -1;
             cache.cells[i].seq_id.clear();
             if (new_head == cache.size) new_head = i;
         } else {
@@ -3272,9 +3810,12 @@ static void llama_kv_cache_seq_add(
     if (cache.recurrent) {
         // for Mamba-like models, only the pos needs to be shifted
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
-            llama_kv_cell & cell = cache.cells[seq_id];
-            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                cell.pos += delta;
+            const int32_t tail_id = cache.cells[seq_id].tail;
+            if (tail_id >= 0) {
+                llama_kv_cell & cell = cache.cells[tail_id];
+                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                    cell.pos += delta;
+                }
             }
         }
         return;
@@ -3318,9 +3859,12 @@ static void llama_kv_cache_seq_div(
     if (cache.recurrent) {
         // for Mamba-like models, only the pos needs to be changed
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
-            llama_kv_cell & cell = cache.cells[seq_id];
-            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                cell.pos /= d;
+            const int32_t tail_id = cache.cells[seq_id].tail;
+            if (tail_id >= 0) {
+                llama_kv_cell & cell = cache.cells[tail_id];
+                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                    cell.pos /= d;
+                }
             }
         }
         return;
@@ -3352,7 +3896,9 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama
 }
 
 static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
-    cache.do_defrag = true;
+    if (!cache.recurrent) {
+        cache.do_defrag = true;
+    }
 }
 
 static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
@@ -3566,13 +4112,8 @@ namespace GGUFMeta {
 
 using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;
 
-// TODO: update when needed or think of some clever automatic way to do this
-static size_t llama_model_max_nodes(const llama_model & /*model*/) {
-    //if (model.arch == LLM_ARCH_LLAMA && model.hparams.n_layer > ??) { // llama-3 405B
-    //    return 32768;
-    //}
-
-    return 8192;
+static size_t llama_model_max_nodes(const llama_model & model) {
+    return std::max<size_t>(8192, model.tensors_by_name.size()*5);
 }
 
 struct llama_model_loader {
@@ -4892,7 +5433,6 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_PHI3:
             {
-                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
                 switch (hparams.n_layer) {
@@ -4901,6 +5441,22 @@ static void llm_load_hparams(
                     case 40: model.type = e_model::MODEL_14B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
+
+                // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931
+                if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) {
+                    // default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct
+                    hparams.n_swa = 2047;
+                } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) {
+                    // default value for Phi-3-mini-128k-instruct
+                    hparams.n_swa = 262144;
+                } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) {
+                    // default value for Phi-3-medium-128k-instruct
+                    hparams.n_swa = 131072;
+                }
+                bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+                if (!found_swa && hparams.n_swa == 0) {
+                    throw std::runtime_error("invalid value for sliding_window");
+                }
             } break;
         case LLM_ARCH_PLAMO:
             {
@@ -4992,6 +5548,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner);
                 ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state);
                 ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+                ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false);
 
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
@@ -5198,6 +5755,12 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                }
             } break;
+        case LLM_ARCH_T5ENCODER:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
+                model.type = e_model::MODEL_UNKNOWN;
+            } break;
         case LLM_ARCH_JAIS:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -5210,6 +5773,23 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_NEMOTRON:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_4B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
+        case LLM_ARCH_EXAONE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_8B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -5442,6 +6022,15 @@ static void llm_load_vocab(
             } else if (
                 tokenizer_pre == "codeshell") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL;
+            } else if (
+                tokenizer_pre == "bloom") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BLOOM;
+            } else if (
+                tokenizer_pre == "gpt3-finnish") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH;
+            } else if (
+                tokenizer_pre == "exaone") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_EXAONE;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
@@ -5815,6 +6404,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
         LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
         LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
+        LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n",     __func__, hparams.ssm_dt_b_c_rms);
     }
 
     LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model.type));
@@ -6015,6 +6605,7 @@ static bool llm_load_tensors(
         const int64_t n_embd_gqa    = n_embd_v_gqa;
         const int64_t n_vocab       = hparams.n_vocab;
         const int64_t n_vocab_type  = hparams.n_vocab_type;
+        const int64_t n_rot         = hparams.n_rot;
         const int64_t n_expert      = hparams.n_expert;
         const int64_t n_expert_used = hparams.n_expert_used;
         const int64_t n_ctx_train   = hparams.n_ctx_train;
@@ -6072,7 +6663,7 @@ static bool llm_load_tensors(
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
 
-                        layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
 
                         if (n_expert == 0) {
                             layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
@@ -6080,9 +6671,9 @@ static bool llm_load_tensors(
                             layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
 
                             // optional MLP bias
-                            layer.ffn_gate_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            layer.ffn_down_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            layer.ffn_up_b   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
                         } else {
                             layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
 
@@ -6406,7 +6997,7 @@ static bool llm_load_tensors(
                         layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias",   i), {n_embd_gqa});
 
                         layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); //output_dens
-                        layer.bo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}); //output_dens
+                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}); //output_dens
 
                         layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm
                         layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias",   i), {n_embd});
@@ -7432,6 +8023,42 @@ static bool llm_load_tensors(
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP,   "weight", i), {n_embd,   n_ff});
                     }
                 } break;
+            case LLM_ARCH_T5ENCODER:
+                {
+                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
+
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        // if output is NULL, init from the input tok embed
+                        if (model.output == NULL) {
+                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                        }
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm_enc  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd});
+                        layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                        layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
+
+                        layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up_enc   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                    }
+                } break;
             case LLM_ARCH_JAIS:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -7489,8 +8116,8 @@ static bool llm_load_tensors(
 
                         layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + (hparams.n_embd_head_k << 2)});
+                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
 
                         layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
 
@@ -7501,15 +8128,87 @@ static bool llm_load_tensors(
                         layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
                     }
                 } break;
-            default:
-                throw std::runtime_error("unknown architecture");
-        }
-    }
+            case LLM_ARCH_NEMOTRON:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
 
-    ml.done_getting_tensors();
+                    // output
+                    {
+                        model.output_norm   = ml.create_tensor(ctx_output,   tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
+                        model.output        = ml.create_tensor(ctx_output_split,  tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+                    }
 
-    ml.init_mappings(true, use_mlock ? &model.mlock_mmaps : nullptr);
-    model.mappings.reserve(ml.mappings.size());
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                        // optional bias tensors
+                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
+
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+
+                        // optional MLP bias
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    }
+                } break;
+            case LLM_ARCH_EXAONE:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                    }
+                } break;
+            default:
+                throw std::runtime_error("unknown architecture");
+        }
+    }
+
+    ml.done_getting_tensors();
+
+    ml.init_mappings(true, use_mlock ? &model.mlock_mmaps : nullptr);
+    model.mappings.reserve(ml.mappings.size());
 
     // create the backend buffers
     std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_bufs;
@@ -7742,7 +8441,7 @@ static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
        struct llama_context & lctx,
         const llama_hparams & hparams,
-          const llama_batch & batch,
+         const llama_ubatch & batch,
          struct ggml_tensor * tok_embd,
          const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
@@ -8176,9 +8875,10 @@ static struct ggml_tensor * llm_build_kqv(
                     0);
         cb(v, "v", il);
 
-        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+                                  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
         }
 
@@ -8187,7 +8887,7 @@ static struct ggml_tensor * llm_build_kqv(
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM) {
             // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
             // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
             ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -8291,12 +8991,180 @@ static struct ggml_tensor * llm_build_kv(
     return cur;
 }
 
+static struct ggml_tensor * llm_build_copy_mask_state(
+        struct ggml_context * ctx,
+         struct ggml_cgraph * graph,
+         struct ggml_tensor * s,
+         struct ggml_tensor * state_copy,
+         struct ggml_tensor * state_mask,
+                    int32_t   n_state,
+                    int32_t   kv_size,
+                    int32_t   kv_head,
+                    int32_t   n_kv,
+                    int32_t   n_seqs) {
+    struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size);
+
+    // copy states
+    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+    // this shrinks the tensors's ne[1] to n_kv
+    states = ggml_get_rows(ctx, states, state_copy);
+
+    // clear states of sequences which are starting at the beginning of this batch
+    // FIXME: zero-out NANs?
+    states = ggml_mul(ctx, states, state_mask);
+
+    // copy states which won't be changed further (between n_seqs and n_rs)
+    ggml_build_forward_expand(graph,
+        ggml_cpy(ctx,
+            ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)),
+            ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+
+    // the part of the states that will be used and modified
+    return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0);
+}
+
+// TODO: split
+static struct ggml_tensor * llm_build_mamba(
+        struct ggml_context * ctx,
+       struct llama_context & lctx,
+         const llama_ubatch & batch,
+         struct ggml_cgraph * graph,
+         struct ggml_tensor * cur,
+         struct ggml_tensor * state_copy,
+         struct ggml_tensor * state_mask,
+                    int32_t   kv_head,
+                    int32_t   n_kv,
+         const llm_build_cb & cb,
+                    int       il) {
+    const llama_model    & model   = lctx.model;
+    const llama_hparams  & hparams = model.hparams;
+    const llama_kv_cache & kv      = lctx.kv_self;
+    const int64_t d_conv  = hparams.ssm_d_conv;
+    const int64_t d_inner = hparams.ssm_d_inner;
+    const int64_t d_state = hparams.ssm_d_state;
+    const int64_t dt_rank = hparams.ssm_dt_rank;
+    const int64_t n_seqs  = batch.n_seqs;
+    // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
+    const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
+    // Use the same RMS norm as the final layer norm
+    const float norm_rms_eps = hparams.f_norm_rms_eps;
+
+    const int64_t n_seq_tokens = batch.n_seq_tokens;
+
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(batch.equal_seqs);
+    GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+
+    struct ggml_tensor * conv_states_all = kv.k_l[il];
+    struct ggml_tensor * ssm_states_all  = kv.v_l[il];
+
+    // (ab)using the KV cache to store the states
+    struct ggml_tensor * conv = llm_build_copy_mask_state(ctx,
+            graph, conv_states_all, state_copy, state_mask,
+            hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
+    conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs);
+    struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx,
+            graph, ssm_states_all, state_copy, state_mask,
+            hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs);
+    ssm = ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs);
+
+    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+    cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+    // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
+    struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur);
+    // split the above in two
+    // => {d_inner, n_seq_tokens, n_seqs}
+    struct ggml_tensor * x = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
+    struct ggml_tensor * z = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*ggml_element_size(xz));
+
+    // conv
+    {
+        // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
+        struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_transpose(ctx, x), 0);
+
+        // copy last (d_conv - 1) columns back into the state cache
+        struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
+
+        ggml_build_forward_expand(graph,
+            ggml_cpy(ctx, last_conv,
+                ggml_view_1d(ctx, conv_states_all,
+                    (d_conv - 1)*(d_inner)*(n_seqs),
+                    kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
+
+        // 1D convolution
+        // The equivalent is to make a self-overlapping view of conv_x
+        // over d_conv columns at each stride in the 3rd dimension,
+        // then element-wise multiply that with the conv1d weight,
+        // then sum the elements of each row,
+        // (the last two steps are a dot product over rows (also doable with mul_mat))
+        // then permute away the ne[0] dimension,
+        // and then you're left with the resulting x tensor.
+        // For simultaneous sequences, all sequences need to have the same length.
+        x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d);
+
+        // bias
+        x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b);
+
+        x = ggml_silu(ctx, x);
+    }
+
+    // ssm
+    {
+        // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
+        struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x);
+        // split
+        struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
+        struct ggml_tensor * B  = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
+        struct ggml_tensor * C  = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));
+
+        // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
+        if (ssm_dt_b_c_rms) {
+            dt = ggml_rms_norm(ctx, dt, norm_rms_eps);
+            B = ggml_rms_norm(ctx, B, norm_rms_eps);
+            C = ggml_rms_norm(ctx, C, norm_rms_eps);
+        }
+
+        // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
+        dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt);
+        dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b);
+
+        // Custom operator to optimize the parallel associative scan
+        // as described in the Annex D of the Mamba paper.
+        // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+        struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C);
+
+        // store last states
+        ggml_build_forward_expand(graph,
+            ggml_cpy(ctx,
+                ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
+                ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
+
+        struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0);
+
+        // TODO: skip computing output earlier for unused tokens
+
+        // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
+        y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d));
+        y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z)));
+
+        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
+        cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y);
+    }
+
+    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+    cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs);
+    cb(cur, "mamba_out", il);
+
+    return cur;
+}
+
 struct llm_build_context {
     const llama_model    & model;
           llama_context  & lctx;
     const llama_hparams  & hparams;
     const llama_cparams  & cparams;
-    const llama_batch    & batch;
+    const llama_ubatch   & batch;
     const llama_kv_cache & kv_self;
 
     const int64_t n_embd;
@@ -8342,7 +9210,7 @@ struct llm_build_context {
     // TODO: consider making the entire interface noexcept
     llm_build_context(
         llama_context  & lctx,
-    const llama_batch  & batch,
+    const llama_ubatch & batch,
     const llm_build_cb & cb,
                   bool   worst_case) :
         model            (lctx.model),
@@ -8449,29 +9317,6 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_s_copy() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
-
-        GGML_ASSERT(kv_self.recurrent);
-
-        struct ggml_tensor * state_copy = build_inp_s_copy();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
-            struct ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
-
-            conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
-            ssm_states  = ggml_get_rows(ctx0,  ssm_states, state_copy);
-
-            // TODO: name the intermediate tensors with cb()
-
-            ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
-            ggml_build_forward_expand(gf, ggml_cpy(ctx0,  ssm_states, kv_self.v_l[il]));
-        }
-
-        return gf;
-    }
-
     struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
@@ -8606,7 +9451,7 @@ struct llm_build_context {
     }
 
     struct ggml_tensor * build_inp_s_copy() {
-        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
         cb(lctx.inp_s_copy, "inp_s_copy", -1);
         ggml_set_input(lctx.inp_s_copy);
         return lctx.inp_s_copy;
@@ -8619,13 +9464,6 @@ struct llm_build_context {
         return lctx.inp_s_mask;
     }
 
-    struct ggml_tensor * build_inp_s_seq() {
-        lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
-        cb(lctx.inp_s_seq, "inp_s_seq", -1);
-        ggml_set_input(lctx.inp_s_seq);
-        return lctx.inp_s_seq;
-    }
-
     struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
         // find result_norm tensor for input
         struct ggml_tensor * inp = nullptr;
@@ -11955,125 +12793,31 @@ struct llm_build_context {
     struct ggml_cgraph * build_mamba() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
-        const int64_t d_model = n_embd;
-        const int64_t d_conv  = hparams.ssm_d_conv;
-        const int64_t d_inner = hparams.ssm_d_inner;
-        GGML_ASSERT(2 * d_model == d_inner);
-        const int64_t d_state = hparams.ssm_d_state;
-        const int64_t dt_rank = hparams.ssm_dt_rank;
-
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
         // {n_embd, n_tokens}
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
+        struct ggml_tensor * state_copy = build_inp_s_copy();
         struct ggml_tensor * state_mask = build_inp_s_mask();
-        struct ggml_tensor * state_seq  = build_inp_s_seq();
 
         for (int il = 0; il < n_layer; ++il) {
-            // (ab)using the KV cache to store the states
-            struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
-            struct ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
-
-            // clear states of sequences which are starting at the beginning of this batch
-            {
-                conv_states = ggml_mul(ctx0,
-                    ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
-                    state_mask);
-                ssm_states  = ggml_mul(ctx0,
-                    ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]),
-                    state_mask);
-            }
-
-            conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
-            ssm_states  = ggml_reshape_3d(ctx0,  ssm_states,    d_state, d_inner, n_kv);
-
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-            // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens}
-            struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_in, cur);
-            // split the above in two
-            // => {d_inner, n_tokens}
-            struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
-            struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
-
-            // conv
-            {
-                // Custom operator which is needed only to ease simultaneous sequence processing.
-                // For a single sequence, the equivalent is to concatenate the columns of conv_states and x,
-                // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension,
-                // then element-wise multiply that with the conv1d weigth,
-                // then sum the elements of each row,
-                // (the last two steps are a dot product over rows (also doable with mul_mat))
-                // then permute away the ne[0] dimension,
-                // and then you're left with the resulting x tensor.
-                // The new conv_states is the last (d_conv - 1) columns
-                // of the last 3rd dimensional "layer" of the self-overlapping view.
-                // For simultaneous sequences, it's more complicated.
-                struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq);
-
-                // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache
-                ggml_build_forward_expand(gf,
-                    ggml_cpy(ctx0,
-                        ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
-                        ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
-
-                // extract x from x_conv
-                x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0);
-
-                // bias
-                x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
-
-                x = ggml_silu(ctx0, x);
-            }
-
-            // ssm
-            {
-                // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
-                struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_x, x);
-                // split
-                struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
-                struct ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
-                struct ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
-
-                // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
-                dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
-                dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
-
-                // Custom operator to optimize the parallel associative scan
-                // as described in the Annex D of the Mamba paper.
-                // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined,
-                // because only a single tensor can be returned.
-                struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq);
-
-                // store last states (the second part of y_ssm_states)
-                ggml_build_forward_expand(gf,
-                    ggml_cpy(ctx0,
-                        ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
-                        ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*ggml_element_size(ssm_states))));
-
-                struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
-
-                if (il == n_layer - 1) {
-                    // skip computing output for unused tokens
-                    struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                    x    = ggml_get_rows(ctx0,    x, inp_out_ids);
-                    y    = ggml_get_rows(ctx0,    y, inp_out_ids);
-                    z    = ggml_get_rows(ctx0,    z, inp_out_ids);
-                    inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-                }
-
-                // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
-                y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
-                y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
+            cur = llm_build_mamba(ctx0, lctx, batch, gf, cur,
+                    state_copy, state_mask,
+                    kv_head, n_kv, cb, il);
 
-                // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens}
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, y);
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
             }
 
             // residual
@@ -13146,7 +13890,7 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_t5() {
+    struct ggml_cgraph * build_t5_encoder() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
@@ -13161,303 +13905,323 @@ struct llm_build_context {
 
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
-        if (lctx.is_encoding) {
-            struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
+        GGML_ASSERT(lctx.is_encoding);
+        struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
 
-            // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-            struct ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false);
-
-            for (int il = 0; il < n_layer; ++il) {
-                struct ggml_tensor * inpSA = inpL;
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false);
 
-                // norm
-                cur = llm_build_norm(ctx0, inpL, hparams,
-                        model.layers[il].attn_norm_enc, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "attn_norm", il);
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
 
-                // self-attention
-                {
-                    struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq_enc, cur);
-                    cb(Qcur, "Qcur", il);
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm_enc, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
 
-                    struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk_enc, cur);
-                    cb(Kcur, "Kcur", il);
+            // self-attention
+            {
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_enc, cur);
+                cb(Qcur, "Qcur", il);
 
-                    struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv_enc, cur);
-                    cb(Vcur, "Vcur", il);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_enc, cur);
+                cb(Kcur, "Kcur", il);
 
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_enc, cur);
+                cb(Vcur, "Vcur", il);
 
-                    struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-                    struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
-                    struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-                    cb(kq, "kq", il);
+                struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+                struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
 
-                    struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
-                    struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_enc, attn_rel_b);
-                    struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
-                    cb(kq_b, "kq_b", il);
+                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+                cb(kq, "kq", il);
 
-                    kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias);
-                    cb(kq, "kq_soft_max_ext", il);
+                struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
+                struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_enc, attn_rel_b);
+                struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
+                cb(kq_b, "kq_b", il);
 
-                    struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
-                    cb(v, "v", il);
+                kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias);
+                cb(kq, "kq_soft_max_ext", il);
 
-                    struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
-                    cb(kqv, "kqv", il);
+                struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
+                cb(v, "v", il);
 
-                    struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                    cb(kqv_merged, "kqv_merged", il);
+                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
+                cb(kqv, "kqv", il);
 
-                    cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
-                    cb(cur, "kqv_merged_cont", il);
+                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                cb(kqv_merged, "kqv_merged", il);
 
-                    ggml_build_forward_expand(gf, cur);
+                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+                cb(cur, "kqv_merged_cont", il);
 
-                    cur = ggml_mul_mat(ctx0, model.layers[il].wo_enc, cur);
-                    cb(cur, "kqv_out", il);
-                }
+                ggml_build_forward_expand(gf, cur);
 
-                if (il == n_layer - 1) {
-                    // skip computing output for unused tokens
-                    struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                    n_tokens = n_outputs;
-                    cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                    inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-                }
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_enc, cur);
+                cb(cur, "kqv_out", il);
+            }
 
-                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-                cb(ffn_inp, "ffn_inp", il);
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
 
-                // feed-forward network
-                {
-                    cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                            model.layers[il].ffn_norm_enc, NULL,
-                            LLM_NORM_RMS, cb, il);
-                    cb(cur, "ffn_norm", il);
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
 
-                    // T5 uses relu, flan-T5 uses gelu-gated
-                    cur = llm_build_ffn(ctx0, lctx, cur,
-                            model.layers[il].ffn_up_enc,   NULL, NULL,
-                            model.layers[il].ffn_gate_enc, NULL, NULL,
-                            model.layers[il].ffn_down_enc, NULL, NULL,
-                            NULL,
-                            model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
-                            model.layers[il].ffn_gate_enc ? LLM_FFN_PAR  : LLM_FFN_SEQ,
-                            cb, il);
-                    cb(cur, "ffn_out", il);
-                }
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm_enc, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
 
-                cur = ggml_add(ctx0, cur, ffn_inp);
+                // T5 uses relu, flan-T5 uses gelu-gated
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up_enc,   NULL, NULL,
+                        model.layers[il].ffn_gate_enc, NULL, NULL,
+                        model.layers[il].ffn_down_enc, NULL, NULL,
+                        NULL,
+                        model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
+                        model.layers[il].ffn_gate_enc ? LLM_FFN_PAR  : LLM_FFN_SEQ,
+                        cb, il);
                 cb(cur, "ffn_out", il);
+            }
 
-                ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-                if (layer_dir != nullptr) {
-                    cur = ggml_add(ctx0, cur, layer_dir);
-                }
-                cb(cur, "l_out", il);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
 
-                // input for next layer
-                inpL = cur;
+            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+            if (layer_dir != nullptr) {
+                cur = ggml_add(ctx0, cur, layer_dir);
             }
+            cb(cur, "l_out", il);
 
-            cur = inpL;
-            cb(cur, "result_embd", -1);
+            // input for next layer
+            inpL = cur;
+        }
 
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    model.output_norm_enc, NULL,
-                    LLM_NORM_RMS, cb, -1);
-            cb(cur, "result_norm", -1);
-        } else {
-            GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
+        cur = inpL;
+        cb(cur, "result_embd", -1);
 
-            struct ggml_tensor * embd_enc       = llm_build_inp_embd_enc();
-            struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm_enc, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
 
-            struct ggml_tensor * KQ_mask_dec   = build_inp_KQ_mask();
-            struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross();
+        ggml_build_forward_expand(gf, cur);
 
-            for (int il = 0; il < n_layer; ++il) {
-                struct ggml_tensor * inpSA = inpL;
+        return gf;
+    }
 
-                // norm
-                cur = llm_build_norm(ctx0, inpL, hparams,
-                        model.layers[il].attn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "attn_norm", il);
+    struct ggml_cgraph * build_t5_decoder() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
-                // self-attention
-                {
-                    struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-                    cb(Qcur, "Qcur", il);
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
 
-                    struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
-                    cb(Kcur, "Kcur", il);
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
-                    struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
-                    cb(Vcur, "Vcur", il);
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
 
-                    llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
-                    struct ggml_tensor * k =
-                        ggml_view_3d(ctx0, kv_self.k_l[il],
-                                n_embd_head_k, n_kv, n_head_kv,
-                                ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                                ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
-                                0);
-                    cb(k, "k", il);
+        GGML_ASSERT(!lctx.is_encoding);
+        GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
 
-                    struct ggml_tensor * v =
-                        ggml_view_3d(ctx0, kv_self.v_l[il],
-                                n_kv, n_embd_head_v, n_head_kv,
-                                ggml_element_size(kv_self.v_l[il])*n_ctx,
-                                ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
-                                0);
-                    cb(v, "v", il);
+        struct ggml_tensor * embd_enc       = llm_build_inp_embd_enc();
+        struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
 
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+        struct ggml_tensor * KQ_mask_dec   = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross();
 
-                    struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
 
-                    struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-                    cb(kq, "kq", il);
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
 
-                    struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
-                    struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b);
-                    struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
-                    cb(kq_b, "kq_b", il);
+            // self-attention
+            {
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
 
-                    kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias);
-                    cb(kq, "kq_soft_max_ext", il);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
 
-                    struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-                    cb(kqv, "kqv", il);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
 
-                    struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                    cb(kqv_merged, "kqv_merged", il);
+                llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
 
-                    cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
-                    cb(cur, "kqv_merged_cont", il);
+                struct ggml_tensor * k =
+                    ggml_view_3d(ctx0, kv_self.k_l[il],
+                            n_embd_head_k, n_kv, n_head_kv,
+                            ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                            ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+                            0);
+                cb(k, "k", il);
 
-                    ggml_build_forward_expand(gf, cur);
+                struct ggml_tensor * v =
+                    ggml_view_3d(ctx0, kv_self.v_l[il],
+                            n_kv, n_embd_head_v, n_head_kv,
+                            ggml_element_size(kv_self.v_l[il])*n_ctx,
+                            ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+                            0);
+                cb(v, "v", il);
 
-                    cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
-                    cb(cur, "kqv_out", il);
-                }
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                cur = ggml_add(ctx0, cur, inpSA);
-                cb(cur, "cross_inp", il);
+                struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
 
-                struct ggml_tensor * inpCA = cur;
+                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+                cb(kq, "kq", il);
 
-                // norm
-                cur = llm_build_norm(ctx0, cur, hparams,
-                        model.layers[il].attn_norm_cross, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "attn_norm_cross", il);
+                struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
+                struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b);
+                struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
+                cb(kq_b, "kq_b", il);
 
-                // cross-attention
-                {
-                    struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq_cross, cur);
-                    cb(Qcur, "Qcur", il);
+                kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias);
+                cb(kq, "kq_soft_max_ext", il);
 
-                    struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk_cross, embd_enc);
-                    cb(Kcur, "Kcur", il);
+                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+                cb(kqv, "kqv", il);
 
-                    struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv_cross, embd_enc);
-                    cb(Vcur, "Vcur", il);
+                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                cb(kqv_merged, "kqv_merged", il);
 
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
+                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+                cb(cur, "kqv_merged_cont", il);
 
-                    struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-                    struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+                ggml_build_forward_expand(gf, cur);
 
-                    struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-                    cb(kq, "kq", il);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+                cb(cur, "kqv_out", il);
+            }
 
-                    kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias);
-                    cb(kq, "kq_soft_max_ext", il);
+            cur = ggml_add(ctx0, cur, inpSA);
+            cb(cur, "cross_inp", il);
 
-                    struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc)));
-                    cb(v, "v", il);
+            struct ggml_tensor * inpCA = cur;
 
-                    struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq);
-                    cb(kqv, "kqv", il);
+            // norm
+            cur = llm_build_norm(ctx0, cur, hparams,
+                    model.layers[il].attn_norm_cross, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm_cross", il);
 
-                    struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                    cb(kqv_merged, "kqv_merged", il);
+            // cross-attention
+            {
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_cross, cur);
+                cb(Qcur, "Qcur", il);
 
-                    cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
-                    cb(cur, "kqv_merged_cont", il);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_cross, embd_enc);
+                cb(Kcur, "Kcur", il);
 
-                    ggml_build_forward_expand(gf, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_cross, embd_enc);
+                cb(Vcur, "Vcur", il);
 
-                    cur = ggml_mul_mat(ctx0, model.layers[il].wo_cross, cur);
-                    cb(cur, "kqv_out", il);
-                }
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
 
-                if (il == n_layer - 1) {
-                    // skip computing output for unused tokens
-                    struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                    n_tokens = n_outputs;
-                    cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                    inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-                    inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
-                }
+                struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+                struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
 
-                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA);
-                cb(ffn_inp, "ffn_inp", il);
+                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+                cb(kq, "kq", il);
 
-                // feed-forward network
-                {
-                    cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                            model.layers[il].ffn_norm, NULL,
-                            LLM_NORM_RMS, cb, il);
-                    cb(cur, "ffn_norm", il);
+                kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias);
+                cb(kq, "kq_soft_max_ext", il);
 
-                    // T5 uses relu, flan-T5 uses gelu-gated
-                    cur = llm_build_ffn(ctx0, lctx, cur,
-                            model.layers[il].ffn_up,   NULL, NULL,
-                            model.layers[il].ffn_gate, NULL, NULL,
-                            model.layers[il].ffn_down, NULL, NULL,
-                            NULL,
-                            model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
-                            model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
-                            cb, il);
-                    cb(cur, "ffn_out", il);
-                }
+                struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc)));
+                cb(v, "v", il);
 
-                cur = ggml_add(ctx0, cur, ffn_inp);
-                cb(cur, "ffn_out", il);
+                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq);
+                cb(kqv, "kqv", il);
 
-                ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-                if (layer_dir != nullptr) {
-                    cur = ggml_add(ctx0, cur, layer_dir);
-                }
-                cb(cur, "l_out", il);
+                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                cb(kqv_merged, "kqv_merged", il);
 
-                // input for next layer
-                inpL = cur;
+                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+                cb(cur, "kqv_merged_cont", il);
+
+                ggml_build_forward_expand(gf, cur);
+
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_cross, cur);
+                cb(cur, "kqv_out", il);
             }
 
-            cur = inpL;
-            cb(cur, "result_embd", -1);
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
+            }
 
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    model.output_norm, NULL,
-                    LLM_NORM_RMS, cb, -1);
-            cb(cur, "result_norm", -1);
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA);
+            cb(ffn_inp, "ffn_inp", il);
 
-            // lm_head
-            cur = ggml_mul_mat(ctx0, model.output, cur);
-            cb(cur, "result_output", -1);
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                // T5 uses relu, flan-T5 uses gelu-gated
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
+                        model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
+                        cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+            if (layer_dir != nullptr) {
+                cur = ggml_add(ctx0, cur, layer_dir);
+            }
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
         }
 
+        cur = inpL;
+        cb(cur, "result_embd", -1);
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
         ggml_build_forward_expand(gf, cur);
 
         return gf;
@@ -13668,11 +14432,259 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_nemotron() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        //GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm,
+                    model.layers[il].ffn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                    NULL,                      NULL,                        NULL,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                    NULL,
+                    LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    struct ggml_cgraph * build_exaone() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
-    llama_batch dummy;
-    dummy.n_tokens = 0;
+    llama_ubatch dummy = {};
+    dummy.equal_seqs = true;
 
     llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
 
@@ -13688,8 +14700,8 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
 }
 
 static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
-    llama_batch dummy;
-    dummy.n_tokens = 0;
+    llama_ubatch dummy = {};
+    dummy.equal_seqs = true;
 
     llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
 
@@ -13704,26 +14716,9 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
     return result;
 }
 
-static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
-    llama_batch dummy;
-    dummy.n_tokens = 0;
-
-    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
-
-    struct llm_build_context llm(lctx, dummy, cb, false);
-
-    llm.init();
-
-    struct ggml_cgraph * result = llm.build_s_copy();
-
-    llm.free();
-
-    return result;
-}
-
 static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
-     const llama_batch & batch,
+    const llama_ubatch & batch,
                   bool   worst_case) {
     const auto & model = lctx.model;
 
@@ -13909,12 +14904,28 @@ static struct ggml_cgraph * llama_build_graph(
             } break;
         case LLM_ARCH_T5:
             {
-                result = llm.build_t5();
+                if (lctx.is_encoding) {
+                    result = llm.build_t5_encoder();
+                } else {
+                    result = llm.build_t5_decoder();
+                }
+            } break;
+        case LLM_ARCH_T5ENCODER:
+            {
+                result = llm.build_t5_encoder();
             } break;
         case LLM_ARCH_JAIS:
             {
                 result = llm.build_jais();
             } break;
+        case LLM_ARCH_NEMOTRON:
+            {
+                result = llm.build_nemotron();
+            } break;
+        case LLM_ARCH_EXAONE:
+            {
+                result = llm.build_exaone();
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -13977,7 +14988,7 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }
 
-static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
+static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     //
     // set input data
     //
@@ -14016,10 +15027,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             for (int i = 0; i < n_tokens; ++i) {
                 data[i] = i;
             }
-        } else if (batch.logits) {
+        } else if (batch.output) {
             int32_t n_outputs = 0;
             for (int i = 0; i < n_tokens; ++i) {
-                if (batch.logits[i]) {
+                if (batch.output[i]) {
                     data[n_outputs++] = i;
                 }
             }
@@ -14043,8 +15054,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn && !lctx.is_encoding) {
-            const int64_t n_kv     = kv_self.n;
-            const int64_t n_tokens = batch.n_tokens;
+            const int64_t n_kv         = kv_self.n;
+            const int64_t n_tokens     = batch.n_tokens;
+            const int64_t n_seq_tokens = batch.n_seq_tokens;
+            const int64_t n_seqs       = batch.n_seqs;
 
 
             float * data     = nullptr;
@@ -14064,32 +15077,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             // of the correct sequence for each token of the batch.
             // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
             for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    const llama_pos    pos    = batch.pos[j];
-                    const llama_seq_id seq_id = batch.seq_id[j][0];
+                for (int s = 0; s < n_seqs; ++s) {
+                    const llama_seq_id seq_id = batch.seq_id[s][0];
 
-                    for (int i = 0; i < n_kv; ++i) {
-                        float f;
-                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
-                            f = -INFINITY;
-                        } else {
-                            if (hparams.use_alibi) {
-                                f = -std::abs(lctx.kv_self.cells[i].pos - pos);
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const llama_pos pos = batch.pos[s*n_seq_tokens + j];
+
+                        for (int i = 0; i < n_kv; ++i) {
+                            float f;
+                            if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                                f = -INFINITY;
                             } else {
-                                f = 0.0f;
+                                if (hparams.use_alibi) {
+                                    f = -std::abs(kv_self.cells[i].pos - pos);
+                                } else {
+                                    f = 0.0f;
+                                }
                             }
-                        }
 
-                        if (data) {
-                            data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
-                        }
+                            if (data) {
+                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                            }
 
-                        // may need to cut off old tokens for sliding window
-                        if (data_swa) {
-                            if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
-                                f = -INFINITY;
+                            // may need to cut off old tokens for sliding window
+                            if (data_swa) {
+                                if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+                                    f = -INFINITY;
+                                }
+                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                             }
-                            data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                         }
                     }
                 }
@@ -14111,8 +15127,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 }
             }
         } else {
+            const int64_t n_tokens     = batch.n_tokens;
+            const int64_t n_seq_tokens = batch.n_seq_tokens;
+            const int64_t n_seqs       = batch.n_seqs;
             // when using kv cache, the mask needs to match the kv cache size
-            const int64_t n_tokens = batch.n_tokens;
             const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
 
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
@@ -14120,27 +15138,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             float * data = (float *) lctx.inp_KQ_mask->data;
 
             for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    const llama_seq_id seq_id = batch.seq_id[j][0];
-
-                    for (int i = 0; i < n_tokens; ++i) {
-                        float f = -INFINITY;
-                        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                            if (batch.seq_id[i][s] == seq_id) {
-                                if (hparams.use_alibi) {
-                                    f = -std::abs(batch.pos[i] - batch.pos[j]);
-                                } else {
-                                    f = 0.0f;
+                for (int s1 = 0; s1 < n_seqs; ++s1) {
+                    const llama_seq_id seq_id = batch.seq_id[s1][0];
+
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const int32_t tj = s1*n_seq_tokens + j;
+
+                        for (int s0 = 0; s0 < n_seqs; ++s0) {
+                            for (int i = 0; i < n_seq_tokens; ++i) {
+                                const int32_t ti = s0*n_seq_tokens + i;
+                                float f = -INFINITY;
+
+                                for (int s = 0; s < batch.n_seq_id[s0]; ++s) {
+                                    if (batch.seq_id[s0][s] == seq_id) {
+                                        if (hparams.use_alibi) {
+                                            f = -std::abs(batch.pos[ti] - batch.pos[tj]);
+                                        } else {
+                                            f = 0.0f;
+                                        }
+                                        break;
+                                    }
                                 }
-                                break;
+
+                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
                             }
                         }
 
-                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
-                    }
-
-                    for (int i = n_tokens; i < n_stride; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                        for (int i = n_tokens; i < n_stride; ++i) {
+                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+                        }
                     }
                 }
             }
@@ -14148,7 +15174,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }
 
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_tokens     = batch.n_tokens;
+        const int64_t n_seq_tokens = batch.n_seq_tokens;
+        const int64_t n_seqs       = batch.n_seqs;
 
         GGML_ASSERT(lctx.inp_mean);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
@@ -14157,12 +15185,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
 
         std::vector<uint64_t> sum(n_tokens, 0);
-        for (int i = 0; i < n_tokens; ++i) {
-            const llama_seq_id seq_id = batch.seq_id[i][0];
 
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = batch.seq_id[s][0];
+
+            // TODO: adapt limits to n_seqs when batch.equal_seqs is true
             GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
 
-            sum[seq_id] += 1;
+            sum[seq_id] += batch.n_seq_tokens;
         }
 
         std::vector<float> div(n_tokens, 0.0f);
@@ -14173,14 +15203,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             }
         }
 
-        for (int i = 0; i < n_tokens; ++i) {
-            const llama_seq_id seq_id = batch.seq_id[i][0];
-            data[seq_id*n_tokens + i] = div[seq_id];
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = batch.seq_id[s][0];
+
+            for (int i = 0; i < n_seq_tokens; ++i) {
+                data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
+            }
         }
     }
 
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_tokens     = batch.n_tokens;
+        const int64_t n_seq_tokens = batch.n_seq_tokens;
+        const int64_t n_seqs       = batch.n_seqs;
 
         GGML_ASSERT(lctx.inp_cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -14188,20 +15223,26 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         uint32_t * data = (uint32_t *) lctx.inp_cls->data;
         memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
 
-        for (int i = 0; i < n_tokens; ++i) {
-            const llama_seq_id seq_id = batch.seq_id[i][0];
-            const llama_pos    pos    = batch.pos[i];
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = batch.seq_id[s][0];
 
+            // TODO: adapt limits to n_seqs when batch.equal_seqs is true
             GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
 
-            if (pos == 0) {
-                data[seq_id] = i;
+            for (int i = 0; i < n_seq_tokens; ++i) {
+                const llama_pos pos = batch.pos[s*n_seq_tokens + i];
+
+                if (pos == 0) {
+                    data[seq_id] = s*n_seq_tokens + i;
+                }
             }
         }
     }
 
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_tokens     = batch.n_tokens;
+        const int64_t n_seq_tokens = batch.n_seq_tokens;
+        const int64_t n_seqs       = batch.n_seqs;
 
         GGML_ASSERT(lctx.inp_cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -14212,15 +15253,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         std::vector<int> last_pos(n_tokens, -1);
         std::vector<int> last_row(n_tokens, -1);
 
-        for (int i = 0; i < n_tokens; ++i) {
-            const llama_seq_id seq_id = batch.seq_id[i][0];
-            const llama_pos    pos    = batch.pos[i];
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = batch.seq_id[s][0];
 
+            // TODO: adapt limits to n_seqs when batch.equal_seqs is true
             GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
 
-            if (pos >= last_pos[seq_id]) {
-                last_pos[seq_id] = pos;
-                last_row[seq_id] = i;
+            for (int i = 0; i < n_seq_tokens; ++i) {
+                const llama_pos pos = batch.pos[s*n_seq_tokens + i];
+
+                if (pos >= last_pos[seq_id]) {
+                    last_pos[seq_id] = pos;
+                    last_row[seq_id] = s*n_seq_tokens + i;
+                }
             }
         }
 
@@ -14238,41 +15283,39 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
             float * data = (float *) lctx.inp_s_mask->data;
 
-            // states which are not affected by the current batch are left untouched
+            // clear unused states
             for (int i = 0; i < n_kv; ++i) {
-                llama_seq_id    seq_id       = i + lctx.kv_self.head;
-                llama_kv_cell & kv_cell      = lctx.kv_self.cells[seq_id];
-                bool            has_self_seq = kv_cell.has_seq_id(seq_id);
+                uint32_t        cell_id = i + kv_self.head;
+                llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
 
-                data[i] = (float) has_self_seq;
+                data[i] = (float) (kv_cell.src >= 0);
 
-                // ensure current sequences will be kept
-                if (!has_self_seq && kv_cell.pos >= 0) {
-                    kv_cell.seq_id.insert(seq_id);
+                // only clear once
+                if (kv_cell.src < 0) {
+                    kv_cell.src = cell_id;
                 }
             }
         }
-        // For Mamba (and other recurrent architectures),
-        // update the correct state(s)/sequence(s) for each token of the batch.
-        // Like with the KQ_mask, if a token in the batch has multiple sequences,
-        // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).
-        if (lctx.inp_s_seq) {
-            const int64_t n_tokens = batch.n_tokens;
 
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
-            int32_t * data = (int32_t *) lctx.inp_s_seq->data;
+        if (lctx.inp_s_copy) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+            int32_t * data = (int32_t *) lctx.inp_s_copy->data;
 
-            for (int j = 0; j < n_tokens; ++j) {
-                const int32_t n_seq = batch.n_seq_id[j];
-                GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence
+            // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+            for (uint32_t i = 0; i < n_kv; ++i) {
+                const uint32_t  cell_id = i + kv_self.head;
+                llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
 
-                for (int i = 0; i < n_kv; ++i) {
-                    if (i < n_seq) {
-                        // for this type of model, the head is the minimum seq_id of the batch
-                        data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head;
-                    } else {
-                        data[j*n_kv + i] = -1;
-                    }
+                // prevent out-of-bound sources
+                if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
+                    kv_cell.src = cell_id;
+                }
+
+                data[i] = kv_cell.src;
+
+                // ensure copy only happens once
+                if (kv_cell.src != (int32_t) cell_id) {
+                    kv_cell.src = cell_id;
                 }
             }
         }
@@ -14282,6 +15325,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
+        GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
 
         int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
 
@@ -14317,6 +15361,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
+        GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
 
         float * data = (float *) lctx.inp_KQ_mask_cross->data;
 
@@ -14357,7 +15402,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
 
     // TODO: use a per-batch flag for logits presence instead
     const bool has_logits = !cparams.embeddings;
-    const bool has_embd   =  lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
+    const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
     const size_t embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
@@ -14410,6 +15455,43 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     return n_outputs_max;
 }
 
+// make the outputs have the same order they had in the user-provided batch
+static void llama_output_reorder(struct llama_context * ctx) {
+    std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
+    if (!out_ids.empty()) {
+        uint32_t n_vocab = ctx->model.hparams.n_vocab;
+        uint32_t n_embd  = ctx->model.hparams.n_embd;
+        int32_t n_outputs = ctx->n_outputs;
+        GGML_ASSERT((size_t) n_outputs == out_ids.size());
+        // TODO: is there something more efficient which also minimizes swaps?
+        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+        for (int32_t i = 0; i < n_outputs - 1; ++i) {
+            int32_t j_min = i;
+            for (int32_t j = i + 1; j < n_outputs; ++j) {
+                if (out_ids[j] < out_ids[j_min]) {
+                    j_min = j;
+                }
+            }
+            if (j_min == i) { continue; }
+            std::swap(out_ids[i], out_ids[j_min]);
+            if (ctx->logits_size > 0) {
+                for (uint32_t k = 0; k < n_vocab; k++) {
+                    std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
+                }
+            }
+            if (ctx->embd_size > 0) {
+                for (uint32_t k = 0; k < n_embd; k++) {
+                    std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
+                }
+            }
+        }
+        std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
+        for (int32_t i = 0; i < n_outputs; ++i) {
+            ctx->output_ids[out_ids[i]] = i;
+        }
+        out_ids.clear();
+    }
+}
 
 static void llama_graph_compute(
         llama_context & lctx,
@@ -14482,15 +15564,11 @@ static int llama_decode_internal(
 
     const auto n_ubatch = cparams.n_ubatch;
 
-    // TODO: simplify or deprecate
-    std::vector<llama_pos> pos;
-    std::vector<int32_t>                   n_seq_id;
-    std::vector<llama_seq_id *>            seq_id_arr;
-    std::vector<std::vector<llama_seq_id>> seq_id;
-
     // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
+    lctx.embd_seq.clear();
+
     // count outputs
     if (batch_all.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
@@ -14503,55 +15581,42 @@ static int llama_decode_internal(
         n_outputs = 1;
     }
 
+    lctx.sbatch.from_batch(batch_all, n_embd,
+        /* simple_split */ !kv_self.recurrent,
+        /* logits_all   */ n_outputs == n_tokens_all);
+
     // reserve output buffer
     if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
         return -2;
     };
 
-    // set output mappings
-    if (batch_all.logits) {
-        int32_t i_logits = 0;
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch_all.logits[i]) {
-                lctx.output_ids[i] = i_logits++;
+    while (lctx.sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+        if (kv_self.recurrent) {
+            if (embd_pooled) {
+                // Pooled embeddings cannot be split across ubatches (yet)
+                ubatch = lctx.sbatch.split_seq(n_ubatch);
+            } else {
+                // recurrent model architectures are easier to implement
+                // with equal-length sequences
+                ubatch = lctx.sbatch.split_equal(n_ubatch);
             }
+        } else {
+            ubatch = lctx.sbatch.split_simple(n_ubatch);
         }
-    } else {
-        for (uint32_t i = 0; i < n_outputs; ++i) {
-            lctx.output_ids[i] = i;
-        }
-    }
-
-    for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
-        const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
-        llama_batch u_batch = {
-            /* .n_tokens   = */ (int32_t) n_tokens,
-            /* .token      = */ batch_all.token     ? batch_all.token    + cur_token        : nullptr,
-            /* .embd       = */ batch_all.embd      ? batch_all.embd     + cur_token*n_embd : nullptr,
-            /* .pos        = */ batch_all.pos       ? batch_all.pos      + cur_token        : nullptr,
-            /* .n_seq_id   = */ batch_all.n_seq_id  ? batch_all.n_seq_id + cur_token        : nullptr,
-            /* .seq_id     = */ batch_all.seq_id    ? batch_all.seq_id   + cur_token        : nullptr,
-            /* .logits     = */ batch_all.logits    ? batch_all.logits   + cur_token        : nullptr,
-            /* .all_pos_0  = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
-            /* .all_pos_1  = */ batch_all.all_pos_1,
-            /* .all_seq_id = */ batch_all.all_seq_id,
-        };
+        const uint32_t n_tokens = ubatch.n_tokens;
 
         // count the outputs in this u_batch
         {
             int32_t n_outputs_new = 0;
 
-            if (u_batch.logits && !embd_pooled) {
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += u_batch.logits[i] != 0;
-                }
-            } else if (n_outputs == n_tokens_all) {
+            if (n_outputs == n_tokens_all) {
                 n_outputs_new = n_tokens;
             } else {
-                // keep last output only
-                if (cur_token + n_tokens >= n_tokens_all) {
-                    n_outputs_new = 1;
+                GGML_ASSERT(ubatch.output);
+                for (uint32_t i = 0; i < n_tokens; i++) {
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
                 }
             }
 
@@ -14562,32 +15627,6 @@ static int llama_decode_internal(
         int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
         GGML_ASSERT(n_threads > 0);
 
-        // helpers for smoother batch API transition
-        // after deprecating the llama_eval calls, these will be removed
-        if (u_batch.pos == nullptr) {
-            pos.resize(n_tokens);
-            for (uint32_t i = 0; i < n_tokens; i++) {
-                pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1;
-            }
-
-            u_batch.pos = pos.data();
-        }
-
-        if (u_batch.seq_id == nullptr) {
-            n_seq_id.resize(n_tokens);
-            seq_id.resize(n_tokens);
-            seq_id_arr.resize(n_tokens);
-            for (uint32_t i = 0; i < n_tokens; i++) {
-                n_seq_id[i] = 1;
-                seq_id[i].resize(1);
-                seq_id[i][0] = u_batch.all_seq_id;
-                seq_id_arr[i] = seq_id[i].data();
-            }
-
-            u_batch.n_seq_id = n_seq_id.data();
-            u_batch.seq_id = seq_id_arr.data();
-        }
-
         // non-causal masks do not use the KV cache
         if (hparams.causal_attn) {
             llama_kv_cache_update(&lctx);
@@ -14598,7 +15637,7 @@ static int llama_decode_internal(
                 kv_self.head = 0;
             }
 
-            if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
+            if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
                 return 1;
             }
 
@@ -14617,7 +15656,7 @@ static int llama_decode_internal(
         ggml_backend_sched_reset(lctx.sched);
         ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-        ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
+        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
 
         // the output is always the last tensor in the graph
         struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
@@ -14628,12 +15667,15 @@ static int llama_decode_internal(
             res  = nullptr;
             embd = nullptr;
         } else if (cparams.embeddings) {
-            res = nullptr; // do not extract logits for embedding case
-            embd = gf->nodes[gf->n_nodes - 1];
-            if (strcmp(embd->name, "result_embd_pooled") != 0) {
-                embd = gf->nodes[gf->n_nodes - 2];
+            res  = nullptr; // do not extract logits for embedding case
+            embd = nullptr;
+            for (int i = gf->n_nodes - 1; i >= 0; --i) {
+                if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
+                    embd = gf->nodes[i];
+                    break;
+                }
             }
-            GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
+            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
@@ -14642,7 +15684,7 @@ static int llama_decode_internal(
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
-        llama_set_inputs(lctx, u_batch);
+        llama_set_inputs(lctx, ubatch);
 
         llama_graph_compute(lctx, gf, n_threads);
 
@@ -14700,12 +15742,11 @@ static int llama_decode_internal(
                 case LLAMA_POOLING_TYPE_CLS:
                 case LLAMA_POOLING_TYPE_LAST:
                     {
-                        // extract sequence embeddings
+                        // extract sequence embeddings (cleared before processing each batch)
                         auto & embd_seq_out = lctx.embd_seq;
-                        embd_seq_out.clear();
 
-                        for (uint32_t i = 0; i < n_tokens; i++) {
-                            const llama_seq_id seq_id = u_batch.seq_id[i][0];
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
                             if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                                 continue;
                             }
@@ -14722,6 +15763,25 @@ static int llama_decode_internal(
         n_outputs_prev += lctx.n_outputs;
     }
 
+    // set output mappings
+    {
+        bool sorted_output = true;
+
+        GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs);
+
+        for (size_t i = 0; i < n_outputs; ++i) {
+            size_t out_id = lctx.sbatch.out_ids[i];
+            lctx.output_ids[out_id] = i;
+            if (out_id != i) {
+                sorted_output = false;
+            }
+        }
+
+        if (sorted_output) {
+            lctx.sbatch.out_ids.clear();
+        }
+    }
+
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     lctx.n_outputs = n_outputs;
 
@@ -14786,11 +15846,9 @@ static int llama_encode_internal(
 
     const int64_t n_embd = hparams.n_embd;
 
-    // TODO: simplify or deprecate
-    std::vector<llama_pos> pos;
-    std::vector<int32_t>                   n_seq_id;
-    std::vector<llama_seq_id *>            seq_id_arr;
-    std::vector<std::vector<llama_seq_id>> seq_id;
+    lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+
+    const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
 
     // reserve output buffer
     if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
@@ -14808,45 +15866,34 @@ static int llama_encode_internal(
     const int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
     GGML_ASSERT(n_threads > 0);
 
-    // helpers for smoother batch API transition
-    // after deprecating the llama_eval calls, these will be removed
-    if (batch.pos == nullptr) {
-        pos.resize(n_tokens);
-        for (uint32_t i = 0; i < n_tokens; i++) {
-            pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
-        }
-
-        batch.pos = pos.data();
-    }
-
-    if (batch.seq_id == nullptr) {
-        n_seq_id.resize(n_tokens);
-        seq_id.resize(n_tokens);
-        seq_id_arr.resize(n_tokens);
-        for (uint32_t i = 0; i < n_tokens; i++) {
-            n_seq_id[i] = 1;
-            seq_id[i].resize(1);
-            seq_id[i][0] = batch.all_seq_id;
-            seq_id_arr[i] = seq_id[i].data();
-        }
-
-        batch.n_seq_id = n_seq_id.data();
-        batch.seq_id = seq_id_arr.data();
-    }
-
     ggml_backend_sched_reset(lctx.sched);
     ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-    ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
+    ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
 
     // the output embeddings after the final encoder normalization
-    struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 1];
+    struct ggml_tensor * embd = nullptr;
 
-    GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
+    // there are two cases here
+    if (llama_model_has_decoder(&lctx.model)) {
+        // first case is an encoder-decoder T5 model where embeddings are passed to decoder
+        embd = gf->nodes[gf->n_nodes - 1];
+        GGML_ASSERT(strcmp(embd->name, "result_norm") == 0 && "missing result_output tensor");
+    } else {
+        // second case is an encoder-only T5 model
+        if (cparams.embeddings) {
+            // only output embeddings if required
+            embd = gf->nodes[gf->n_nodes - 1];
+            if (strcmp(embd->name, "result_embd_pooled") != 0) {
+                embd = gf->nodes[gf->n_nodes - 2];
+            }
+            GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
+        }
+    }
 
     ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
-    llama_set_inputs(lctx, batch);
+    llama_set_inputs(lctx, ubatch);
 
     llama_graph_compute(lctx, gf, n_threads);
 
@@ -14855,20 +15902,57 @@ static int llama_encode_internal(
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
         GGML_ASSERT(backend_embd != nullptr);
 
-        // extract token embeddings
-        GGML_ASSERT(lctx.embd != nullptr);
+        if (llama_model_has_decoder(&lctx.model)) {
+            lctx.embd_enc.resize(n_tokens*n_embd);
+            float * embd_out = lctx.embd_enc.data();
+
+            ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+            GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
 
-        lctx.embd_enc.resize(n_tokens*n_embd);
-        float * embd_out = lctx.embd_enc.data();
+            // remember the sequence ids used during the encoding - needed for cross attention later
+            lctx.seq_ids_enc.resize(n_tokens);
+            for (uint32_t i = 0; i < n_tokens; i++) {
+                for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
+                    llama_seq_id seq_id = ubatch.seq_id[i][s];
+                    lctx.seq_ids_enc[i].insert(seq_id);
+                }
+            }
+        } else {
+            GGML_ASSERT(lctx.embd != nullptr);
 
-        ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+            switch (cparams.pooling_type) {
+                case LLAMA_POOLING_TYPE_NONE:
+                    {
+                        // extract token embeddings
+                        GGML_ASSERT(lctx.embd != nullptr);
+                        float * embd_out = lctx.embd;
 
-        // remember the sequence ids used during the encoding - needed for cross attention later
-        lctx.seq_ids_enc.resize(n_tokens);
-        for (uint32_t i = 0; i < n_tokens; i++) {
-            for (int s = 0; s < batch.n_seq_id[i]; s++) {
-                llama_seq_id seq_id = batch.seq_id[i][s];
-                lctx.seq_ids_enc[i].insert(seq_id);
+                        GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size);
+                        ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+                    } break;
+                case LLAMA_POOLING_TYPE_MEAN:
+                case LLAMA_POOLING_TYPE_CLS:
+                case LLAMA_POOLING_TYPE_LAST:
+                    {
+                        // extract sequence embeddings
+                        auto & embd_seq_out = lctx.embd_seq;
+                        embd_seq_out.clear();
+
+                        GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
+
+                        for (uint32_t i = 0; i < n_tokens; i++) {
+                            const llama_seq_id seq_id = ubatch.seq_id[i][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(n_embd);
+                            ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                    {
+                        GGML_ABORT("unknown pooling type");
+                    }
             }
         }
     }
@@ -15135,32 +16219,6 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         }
     }
 
-    if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
-        {
-            ggml_backend_sched_reset(lctx.sched);
-
-            ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
-
-            ggml_backend_sched_alloc_graph(lctx.sched, gf);
-
-            llama_set_s_copy(lctx);
-
-            llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
-
-            need_reserve = true;
-        }
-
-        {
-            auto & kv_self = lctx.kv_self;
-
-            kv_self.do_copy = false;
-
-            for (uint32_t i = 0; i < kv_self.size; ++i) {
-                kv_self.cells[i].src = i;
-            }
-        }
-    }
-
     // defragment the KV cache if needed
     if (lctx.kv_self.do_defrag) {
         llama_kv_cache_defrag_internal(lctx);
@@ -15174,10 +16232,11 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     if (need_reserve) {
         // TODO: extract to a function
         // build worst-case graph
-        int n_tokens = (int)std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
-        int n_past = lctx.cparams.n_ctx - n_tokens;
+        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+        uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
         llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-        ggml_cgraph * gf = llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
+        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
 
         // initialize scheduler with the worst-case graph
         ggml_backend_sched_reset(lctx.sched);
@@ -15304,7 +16363,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
     const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
     auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
         if (n_expert > 1) {
-            // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but iccasionally randomly
+            // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly
             // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
             // for getting the current layer as I initially thought, and we need to resort to parsing the
             // tensor name.
@@ -15569,6 +16628,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
             case GGML_TYPE_Q6_K:   new_type = GGML_TYPE_Q8_0;   break;
             default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
         }
+        if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
+            new_type = GGML_TYPE_F16;
+        }
         LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
         ++qs.n_fallback;
     }
@@ -15760,7 +16822,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
         // TODO: avoid hardcoded tensor names - use the TN_* constants
         if (name.find("attn_v.weight")   != std::string::npos ||
-            name.find("attn_qkv.weight") != std::string::npos) {
+            name.find("attn_qkv.weight") != std::string::npos ||
+            name.find("attn_kv_b.weight")!= std::string::npos) {
             ++qs.n_attention_wv;
         } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
             qs.has_output = true;
@@ -15770,12 +16833,15 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
 
     // sanity checks
-    //
-    //  - qs.n_attention_wv == 0                         for Mamba           models
-    //  - qs.n_attention_wv == model.hparams.n_layer     for Transformer     models
-    //  - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models
-    //
-    GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer) && "n_attention_wv is unexpected");
+    {
+        const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
+        // attention layers have a non-zero number of kv heads
+        int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
+        if (llama_model_has_encoder(&model)) {
+            n_attn_layer *= 3;
+        }
+        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
+    }
 
     size_t total_size_org = 0;
     size_t total_size_new = 0;
@@ -15897,8 +16963,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         // do not quantize Mamba's small yet 2D weights
         // NOTE: can't use LLM_TN here because the layer number is not known
         quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
-        quantize &= name.find("ssm_x.weight")      == std::string::npos;
-        quantize &= name.find("ssm_dt.weight")     == std::string::npos;
 
         // do not quantize relative position bias (T5)
         quantize &= name.find("attn_rel_b.weight") == std::string::npos;
@@ -16472,12 +17536,6 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }
 
-    if (params.flash_attn && model->hparams.attn_soft_cap) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
-
     if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
         LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;
@@ -16578,13 +17636,15 @@ struct llama_context * llama_new_context_with_model(
 
     ctx->sampling.rng = std::mt19937(params.seed);
     ctx->logits_all   = params.logits_all;
+    // build worst-case graph for encoder if a model contains encoder
+    ctx->is_encoding  = llama_model_has_encoder(model);
 
     uint32_t kv_size = cparams.n_ctx;
     ggml_type type_k = params.type_k;
     ggml_type type_v = params.type_v;
 
     // Mamba only needs a constant number of KV cache cells per sequence
-    if (model->arch == LLM_ARCH_MAMBA) {
+    if (llama_model_is_recurrent(model)) {
         // Mamba needs at least as many KV cells as there are sequences kept at any time
         kv_size = std::max((uint32_t) 1, params.n_seq_max);
         // it's probably best to keep as much precision as possible for the states
@@ -16816,10 +17876,11 @@ struct llama_context * llama_new_context_with_model(
             }
 
             // build worst-case graph
-            int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch);
-            int n_past = cparams.n_ctx - n_tokens;
+            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
             llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-            ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
+            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+            ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true);
 
             // initialize scheduler with the worst-case graph
             if (!ggml_backend_sched_reserve(ctx->sched, gf)) {
@@ -16892,6 +17953,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_MAMBA:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_T5:
+        case LLM_ARCH_T5ENCODER:
         case LLM_ARCH_JAIS:
             return LLAMA_ROPE_TYPE_NONE;
 
@@ -16930,6 +17992,8 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
         case LLM_ARCH_CODESHELL:
+        case LLM_ARCH_NEMOTRON:
+        case LLM_ARCH_EXAONE:
             return LLAMA_ROPE_TYPE_NEOX;
 
         // all model arches should be listed explicitly here
@@ -17039,8 +18103,16 @@ struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const ch
 
 bool llama_model_has_encoder(const struct llama_model * model) {
     switch (model->arch) {
-        case LLM_ARCH_T5: return true;
-        default:          return false;
+        case LLM_ARCH_T5:        return true;
+        case LLM_ARCH_T5ENCODER: return true;
+        default:                 return false;
+    }
+}
+
+bool llama_model_has_decoder(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_T5ENCODER: return false;
+        default:                 return true;
     }
 }
 
@@ -17048,6 +18120,13 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
     return model->hparams.dec_start_token_id;
 }
 
+bool llama_model_is_recurrent(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_MAMBA:  return true;
+        default:              return false;
+    }
+}
+
 uint32_t llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
@@ -17343,6 +18422,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
 // TODO: replace all non-fatal assertions with returned errors or exceptions
 struct llama_data_write {
     virtual void write(const void * src, size_t size) = 0;
+    virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0;
     virtual size_t get_size_written() = 0;
     virtual ~llama_data_write() = default;
 
@@ -17368,7 +18448,9 @@ struct llama_data_write {
         write_string(rng_str);
     }
 
-    void write_output_ids(const struct llama_context * ctx) {
+    void write_output_ids(struct llama_context * ctx) {
+        llama_output_reorder(ctx);
+
         const uint32_t n_outputs = ctx->n_outputs;
 
         std::vector<int32_t> output_pos;
@@ -17465,9 +18547,8 @@ struct llama_data_write {
             // Read each range of cells of k_size length each into tmp_buf and write out
             for (const auto & range : cell_ranges) {
                 const size_t range_size = range.second - range.first;
-                tmp_buf.resize(range_size * k_size_row);
-                ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row);
-                write(tmp_buf.data(), tmp_buf.size());
+                const size_t buf_size = range_size * k_size_row;
+                write_tensor_data(kv_self.k_l[il], range.first * k_size_row, buf_size);
             }
         }
 
@@ -17486,9 +18567,8 @@ struct llama_data_write {
                 // Read each range of cells of v_size length each into tmp_buf and write out
                 for (const auto & range : cell_ranges) {
                     const size_t range_size = range.second - range.first;
-                    tmp_buf.resize(range_size * v_size_row);
-                    ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row);
-                    write(tmp_buf.data(), tmp_buf.size());
+                    const size_t buf_size = range_size * v_size_row;
+                    write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size);
                 }
             }
         } else {
@@ -17514,9 +18594,8 @@ struct llama_data_write {
                     for (const auto & range : cell_ranges) {
                         const size_t range_size = range.second - range.first;
                         const size_t src_offset = (range.first + j * kv_size) * v_size_el;
-                        tmp_buf.resize(range_size * v_size_el);
-                        ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
-                        write(tmp_buf.data(), tmp_buf.size());
+                        const size_t buf_size = range_size * v_size_el;
+                        write_tensor_data(kv_self.v_l[il], src_offset, buf_size);
                     }
                 }
             }
@@ -17659,8 +18738,11 @@ struct llama_data_read {
 
             llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
 
-            llama_batch batch = llama_batch_init(cell_count, 0, 1);
+            llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
             batch.n_tokens = cell_count;
+            batch.n_seq_tokens = cell_count;
+            batch.n_seqs = 1;
+
             for (uint32_t i = 0; i < cell_count; ++i) {
                 llama_pos pos;
                 uint32_t n_seq_id;
@@ -17674,11 +18756,10 @@ struct llama_data_read {
                 }
 
                 batch.pos[i] = pos;
-                batch.n_seq_id[i] = 1;
-                batch.seq_id[i][0] = dest_seq_id;
             }
+            batch.n_seq_id[0] = 1;
+            batch.seq_id[0] = &dest_seq_id;
             if (!llama_kv_cache_find_slot(kv_self, batch)) {
-                llama_batch_free(batch);
                 LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
                 return false;
             }
@@ -17690,9 +18771,6 @@ struct llama_data_read {
             GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
             GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
             GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
-
-            // Cleanup
-            llama_batch_free(batch);
         } else {
             // whole KV cache restore
 
@@ -17724,6 +18802,15 @@ struct llama_data_read {
                     }
 
                     cell.seq_id.insert(seq_id);
+
+                    if (kv_self.recurrent) {
+                        int32_t & tail = kv_self.cells[seq_id].tail;
+                        if (tail != -1) {
+                            LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
+                            return false;
+                        }
+                        tail = i;
+                    }
                 }
             }
 
@@ -17731,6 +18818,14 @@ struct llama_data_read {
             kv_self.used = cell_count;
         }
 
+        if (kv_self.recurrent) {
+            for (uint32_t i = 0; i < cell_count; ++i) {
+                uint32_t cell_id = kv_self.head + i;
+                // make sure the recurrent states will keep their restored state
+                kv_self.cells[cell_id].src = cell_id;
+            }
+        }
+
         return true;
     }
 
@@ -17875,12 +18970,14 @@ struct llama_data_write_dummy : llama_data_write {
 
     llama_data_write_dummy() {}
 
-    // TODO: avoid unnecessary calls to ggml_backend_tensor_get in a dummy context
-
     void write(const void * /* src */, size_t size) override {
         size_written += size;
     }
 
+    void write_tensor_data(const struct ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
+        size_written += size;
+    }
+
     size_t get_size_written() override {
         return size_written;
     }
@@ -17903,6 +19000,16 @@ struct llama_data_write_buffer : llama_data_write {
         buf_size -= size;
     }
 
+    void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
+        if (size > buf_size) {
+            throw std::runtime_error("unexpectedly reached end of buffer");
+        }
+        ggml_backend_tensor_get(tensor, ptr, offset, size);
+        ptr += size;
+        size_written += size;
+        buf_size -= size;
+    }
+
     size_t get_size_written() override {
         return size_written;
     }
@@ -17938,6 +19045,7 @@ struct llama_data_read_buffer : llama_data_read {
 struct llama_data_write_file : llama_data_write {
     llama_file * file;
     size_t size_written = 0;
+    std::vector<uint8_t> temp_buffer;
 
     llama_data_write_file(llama_file * f) : file(f) {}
 
@@ -17946,6 +19054,12 @@ struct llama_data_write_file : llama_data_write {
         size_written += size;
     }
 
+    void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
+        temp_buffer.resize(size);
+        ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
+        write(temp_buffer.data(), temp_buffer.size());
+    }
+
     size_t get_size_written() override {
         return size_written;
     }
@@ -18299,7 +19413,18 @@ struct llama_batch llama_batch_get_one(
 }
 
 struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
-    llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
+    llama_batch batch = {
+        /*n_tokens       =*/ 0,
+        /*tokens         =*/ nullptr,
+        /*embd           =*/ nullptr,
+        /*pos            =*/ nullptr,
+        /*n_seq_id       =*/ nullptr,
+        /*seq_id         =*/ nullptr,
+        /*logits         =*/ nullptr,
+        /*all_pos_0      =*/ 0,
+        /*all_pos_1      =*/ 0,
+        /*all_seq_id     =*/ 0,
+    };
 
     if (embd) {
         batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
@@ -18385,6 +19510,10 @@ void llama_synchronize(struct llama_context * ctx) {
 float * llama_get_logits(struct llama_context * ctx) {
     llama_synchronize(ctx);
 
+    // reorder logits for backward compatibility
+    // TODO: maybe deprecate this
+    llama_output_reorder(ctx);
+
     return ctx->logits;
 }
 
@@ -18429,6 +19558,10 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
 float * llama_get_embeddings(struct llama_context * ctx) {
     llama_synchronize(ctx);
 
+    // reorder embeddings for backward compatibility
+    // TODO: maybe deprecate this
+    llama_output_reorder(ctx);
+
     return ctx->embd;
 }
 
@@ -18530,11 +19663,11 @@ llama_token llama_token_pad(const struct llama_model * model) {
     return llama_token_pad_impl(model->vocab);
 }
 
-int32_t llama_add_bos_token(const struct llama_model * model) {
+bool llama_add_bos_token(const struct llama_model * model) {
     return llama_add_bos_token_impl(model->vocab);
 }
 
-int32_t llama_add_eos_token(const struct llama_model * model) {
+bool llama_add_eos_token(const struct llama_model * model) {
     return llama_add_eos_token_impl(model->vocab);
 }
 
@@ -18835,6 +19968,22 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "Assistant:";
         }
+    } else if (tmpl == "exaone3" || (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]"))) {
+        // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
+        // EXAONE-3.0-7.8B-Instruct
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "user") {
+                ss << "[|user|]" << trim(message->content) << "\n";
+            } else if (role == "assistant") {
+                ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
+            }
+        }
+        if (add_ass) {
+            ss << "[|assistant|]";
+        }
     } else {
         // template not supported
         return -1;
index 66c266298e86f0204f6d466368abd8d86465a99c..6cca6320b347d860246d30d34f174b0f5a0affa6 100644 (file)
@@ -93,15 +93,15 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_TEKKEN         = 20,
         LLAMA_VOCAB_PRE_TYPE_SMOLLM         = 21,
         LLAMA_VOCAB_PRE_TYPE_CODESHELL      = 22,
+        LLAMA_VOCAB_PRE_TYPE_BLOOM          = 23,
+        LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH   = 24,
+        LLAMA_VOCAB_PRE_TYPE_EXAONE         = 25,
     };
 
-    // note: these values should be synchronized with ggml_rope
-    // TODO: maybe move this enum to ggml.h (ggml_rope_type)
     enum llama_rope_type {
         LLAMA_ROPE_TYPE_NONE = -1,
-        LLAMA_ROPE_TYPE_NORM =  0,
-        LLAMA_ROPE_TYPE_NEOX =  2,
-        LLAMA_ROPE_TYPE_GLM  =  4,
+        LLAMA_ROPE_TYPE_NORM = 0,
+        LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
     };
 
     enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
@@ -504,10 +504,16 @@ extern "C" {
     // Returns true if the model contains an encoder that requires llama_encode() call
     LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
 
+    // Returns true if the model contains a decoder that requires llama_decode() call
+    LLAMA_API bool llama_model_has_decoder(const struct llama_model * model);
+
     // For encoder-decoder models, this function returns id of the token that must be provided
     // to the decoder to start generating output sequence. For other models, it returns -1.
     LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
 
+    // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
+    LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
@@ -912,11 +918,8 @@ extern "C" {
     LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
     LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding
 
-    // Returns -1 if unknown, 1 for true or 0 for false.
-    LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
-
-    // Returns -1 if unknown, 1 for true or 0 for false.
-    LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
+    LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
+    LLAMA_API bool llama_add_eos_token(const struct llama_model * model);
 
     // Codellama infill tokens
     LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix