]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : sync llama.cpp
authorGeorgi Gerganov <redacted>
Mon, 28 Jul 2025 07:09:47 +0000 (10:09 +0300)
committerGeorgi Gerganov <redacted>
Mon, 28 Jul 2025 10:02:32 +0000 (13:02 +0300)
27 files changed:
examples/talk-llama/llama-arch.cpp
examples/talk-llama/llama-arch.h
examples/talk-llama/llama-batch.cpp
examples/talk-llama/llama-batch.h
examples/talk-llama/llama-chat.cpp
examples/talk-llama/llama-chat.h
examples/talk-llama/llama-context.cpp
examples/talk-llama/llama-context.h
examples/talk-llama/llama-cparams.h
examples/talk-llama/llama-graph.cpp
examples/talk-llama/llama-graph.h
examples/talk-llama/llama-hparams.cpp
examples/talk-llama/llama-hparams.h
examples/talk-llama/llama-kv-cache-unified-iswa.cpp
examples/talk-llama/llama-kv-cache-unified-iswa.h
examples/talk-llama/llama-kv-cache-unified.cpp
examples/talk-llama/llama-kv-cache-unified.h
examples/talk-llama/llama-memory-hybrid.cpp
examples/talk-llama/llama-memory-recurrent.cpp
examples/talk-llama/llama-model.cpp
examples/talk-llama/llama-model.h
examples/talk-llama/llama-quant.cpp
examples/talk-llama/llama-vocab.cpp
examples/talk-llama/llama-vocab.h
examples/talk-llama/llama.h
examples/talk-llama/unicode.cpp
examples/talk-llama/unicode.h

index e63ab284bc3b59bca499add2f387bc119ec39294..062a99776781f6146783b4670fe14df9ebde32d0 100644 (file)
@@ -34,6 +34,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_PHI3,             "phi3"             },
     { LLM_ARCH_PHIMOE,           "phimoe"           },
     { LLM_ARCH_PLAMO,            "plamo"            },
+    { LLM_ARCH_PLAMO2,           "plamo2"           },
     { LLM_ARCH_CODESHELL,        "codeshell"        },
     { LLM_ARCH_ORION,            "orion"            },
     { LLM_ARCH_INTERNLM2,        "internlm2"        },
@@ -67,6 +68,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_JAIS,             "jais"             },
     { LLM_ARCH_NEMOTRON,         "nemotron"         },
     { LLM_ARCH_EXAONE,           "exaone"           },
+    { LLM_ARCH_EXAONE4,          "exaone4"          },
     { LLM_ARCH_RWKV6,            "rwkv6"            },
     { LLM_ARCH_RWKV6QWEN2,       "rwkv6qwen2"       },
     { LLM_ARCH_RWKV7,            "rwkv7"            },
@@ -81,9 +83,11 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DOTS1,            "dots1"            },
     { LLM_ARCH_ARCEE,            "arcee"            },
     { LLM_ARCH_ERNIE4_5,         "ernie4_5"         },
+    { LLM_ARCH_ERNIE4_5_MOE,     "ernie4_5-moe"     },
     { LLM_ARCH_HUNYUAN_MOE,      "hunyuan-moe"      },
     { LLM_ARCH_SMOLLM3,          "smollm3"          },
     { LLM_ARCH_LFM2,             "lfm2"             },
+    { LLM_ARCH_DREAM,            "dream"            },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -784,6 +788,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_PLAMO2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_SSM_IN,          "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D,      "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_X,           "blk.%d.ssm_x" },
+            { LLM_TENSOR_SSM_DT,          "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_A,           "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_D,           "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_OUT,         "blk.%d.ssm_out" },
+            { LLM_TENSOR_SSM_DT_NORM,     "blk.%d.ssm_dt_norm" },
+            { LLM_TENSOR_SSM_B_NORM,      "blk.%d.ssm_b_norm" },
+            { LLM_TENSOR_SSM_C_NORM,      "blk.%d.ssm_c_norm" },
+            { LLM_TENSOR_ATTN_POST_NORM,  "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_CODESHELL,
         {
@@ -1477,6 +1511,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_EXAONE4,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM,  "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
+        }
+    },
     {
         LLM_ARCH_RWKV6,
         {
@@ -1793,6 +1847,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_ERNIE4_5_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,     "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B,    "blk.%d.exp_probs_b" },
+        },
+    },
     {
         LLM_ARCH_HUNYUAN_MOE,
         {
@@ -1854,6 +1933,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_TOKEN_EMBD_NORM,   "token_embd_norm" },
         }
     },
+    {
+        LLM_ARCH_DREAM,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2094,6 +2190,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
     switch (arch) {
         case LLM_ARCH_JAMBA:
         case LLM_ARCH_FALCON_H1:
+        case LLM_ARCH_PLAMO2:
         case LLM_ARCH_GRANITE_HYBRID:
         case LLM_ARCH_LFM2:
             return true;
@@ -2101,3 +2198,12 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
             return false;
     }
 }
+
+bool llm_arch_is_diffusion(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_DREAM:
+            return true;
+        default:
+            return false;
+    }
+}
index 1f97325952411182d875af8e52a1416f95324016..d09b7d7810b03a2cebb5abc463ca9744805fa100 100644 (file)
@@ -38,6 +38,7 @@ enum llm_arch {
     LLM_ARCH_PHI3,
     LLM_ARCH_PHIMOE,
     LLM_ARCH_PLAMO,
+    LLM_ARCH_PLAMO2,
     LLM_ARCH_CODESHELL,
     LLM_ARCH_ORION,
     LLM_ARCH_INTERNLM2,
@@ -71,6 +72,7 @@ enum llm_arch {
     LLM_ARCH_JAIS,
     LLM_ARCH_NEMOTRON,
     LLM_ARCH_EXAONE,
+    LLM_ARCH_EXAONE4,
     LLM_ARCH_RWKV6,
     LLM_ARCH_RWKV6QWEN2,
     LLM_ARCH_RWKV7,
@@ -85,9 +87,11 @@ enum llm_arch {
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
     LLM_ARCH_ERNIE4_5,
+    LLM_ARCH_ERNIE4_5_MOE,
     LLM_ARCH_HUNYUAN_MOE,
     LLM_ARCH_SMOLLM3,
     LLM_ARCH_LFM2,
+    LLM_ARCH_DREAM,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -478,3 +482,4 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
 
 bool llm_arch_is_recurrent(const llm_arch & arch);
 bool llm_arch_is_hybrid   (const llm_arch & arch);
+bool llm_arch_is_diffusion(const llm_arch & arch);
index 3bc8554e51ccf518e781ba5076780ae757c294a9..a546063c0a7c8c64808c1e87a8db4b3f14dfab1e 100644 (file)
@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
         const llama_vocab & vocab,
         const llama_memory_i * memory,
         uint32_t n_embd,
+        uint32_t n_seq_max,
         bool output_all) {
     clear();
 
@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
     // validate input batch
     //
 
+    if (n_seq_max > LLAMA_MAX_SEQ) {
+        LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
+        return false;
+    }
+
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                     return false;
                 }
             }
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
 
         // initialize the starting position for each sequence based on the positions in the memory
         llama_pos p0[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             if (!memory) {
                 // if no memory -> start from 0
                 p0[s] = 0;
@@ -143,13 +149,16 @@ bool llama_batch_allocr::init(
     // compute stats
     //
 
-    this->n_embd = n_embd;
+    this->n_embd    = n_embd;
+    this->n_seq_max = n_seq_max;
 
     // count the outputs in this batch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         n_outputs += batch.logits[i] != 0;
     }
 
+    has_cpl = false;
+
     // determine coupled sequences
     // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -189,7 +198,7 @@ bool llama_batch_allocr::init(
             seq_set_map[cur].push_back(i);
         }
 
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             if (seq_set_unq.test(s)) {
                 seq_idx[s] = seq_id_unq.size();
                 seq_id_unq.push_back(s);
@@ -201,7 +210,7 @@ bool llama_batch_allocr::init(
         LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
 
         llama_ubatch ubatch {
-            /*.equal_seqs   =*/ false,
+            /*.b_equal_seqs =*/ false,
             /*.n_tokens     =*/ (uint32_t) batch.n_tokens,
             /*.n_seq_tokens =*/ (uint32_t) 1,
             /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
@@ -214,6 +223,7 @@ bool llama_batch_allocr::init(
             /*.seq_id_unq   =*/ this->seq_id_unq.data(),
             /*.seq_idx      =*/ this->seq_idx.data(),
             /*.output       =*/ batch.logits,
+            /*.data         =*/ {},
         };
 
         ubatch_print(ubatch, debug);
@@ -241,7 +251,7 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_pos[s].empty()) {
             continue;
         }
@@ -284,8 +294,8 @@ bool llama_batch_allocr::init(
     }
 
     if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
+        for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
+            for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
                 if (seq_cpl[s0][s1]) {
                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
@@ -316,12 +326,12 @@ bool llama_batch_allocr::init(
     //
     {
         seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             cur_seq_set[s].set();
         }
 
         llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             cur_seq_pos[s] = -1;
         }
 
@@ -357,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
     clear();
     split_reset();
 
-    ubatches.emplace_back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
-    auto & ubatch = ubatches.back();
-
-    ubatch.token     .resize(n_tokens);
-    ubatch.embd      .clear();
-    ubatch.pos       .resize(n_tokens);
-    ubatch.n_seq_id  .resize(n_tokens);
-    ubatch.seq_id    .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
+    udata->token     .resize(n_tokens);
+    udata->embd      .clear();
+    udata->pos       .resize(n_tokens);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);
 
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        ubatch.seq_idx[s] = s;
-        ubatch.seq_id_unq.push_back(s);
+        udata->seq_idx[s] = s;
+        udata->seq_id_unq.push_back(s);
     }
 
     llama_ubatch res {
-        /*.equal_seqs   =*/ true,
+        /*.b_equal_seqs =*/ true,
         /*.n_tokens     =*/ n_tokens,
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ n_seqs,
 
-        /*.token        =*/ ubatch.token.data(),
+        /*.token        =*/ udata->token.data(),
         /*.embd         =*/ nullptr,
-        /*.pos          =*/ ubatch.pos.data(),
-        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
-        /*.seq_id       =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx      =*/ ubatch.seq_idx.data(),
-        /*.output       =*/ ubatch.output.data(),
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.data         =*/ std::move(udata),
     };
 
     return res;
@@ -430,8 +439,6 @@ void llama_batch_allocr::split_reset() {
 
     used.clear();
     used.resize(get_n_tokens(), false);
-
-    ubatches.clear();
 }
 
 llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
@@ -646,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
     assert(n_tokens%n_seqs == 0);
 
-    ubatches.emplace_back();
-
-    auto & ubatch = ubatches.back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
     const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
 
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;
 
-    ubatch.token     .resize(n_tokens);
-    ubatch.embd      .resize(n_embd_all);
-    ubatch.pos       .resize(n_pos_all);
-    ubatch.n_seq_id  .resize(n_tokens);
-    ubatch.seq_id    .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
+    udata->token     .resize(n_tokens);
+    udata->embd      .resize(n_embd_all);
+    udata->pos       .resize(n_pos_all);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);
 
     seq_set_t seq_set_unq;
 
     for (size_t i = 0; i < idxs.size(); ++i) {
         if (batch.token) {
-            ubatch.token[i] = batch.token[idxs[i]];
+            udata->token[i] = batch.token[idxs[i]];
         }
 
         if (batch.embd) {
-            memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
+            memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
         }
 
         for (int j = 0; j < n_pos_cur; ++j) {
-            ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
+            udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
         }
 
-        ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
-        ubatch.seq_id[i]   = batch.seq_id[idxs[i]];
-        ubatch.output[i]   = batch.logits[idxs[i]];
+        udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
+        udata->seq_id[i]   = batch.seq_id[idxs[i]];
+        udata->output[i]   = batch.logits[idxs[i]];
 
-        for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
-            seq_set_unq.set(ubatch.seq_id[i][s]);
+        for (int s = 0; s < udata->n_seq_id[i]; ++s) {
+            seq_set_unq.set(udata->seq_id[i][s]);
         }
 
-        if (ubatch.output[i]) {
+        if (udata->output[i]) {
             out_ids.push_back(idxs[i]);
         }
     }
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
-            ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
-            ubatch.seq_id_unq.push_back(s);
+            udata->seq_idx[s] = udata->seq_id_unq.size();
+            udata->seq_id_unq.push_back(s);
         }
     }
 
     llama_ubatch res {
-        /*.equal_seqs   =*/ equal_seqs,
+        /*.b_equal_seqs =*/ equal_seqs,
         /*.n_tokens     =*/ n_tokens,
         /*.n_seq_tokens =*/ n_tokens/n_seqs,
         /*.n_seqs       =*/ n_seqs,
-        /*.n_seqs_unq   =*/ (uint32_t) ubatch.seq_id_unq.size(),
-
-        /*.token        =*/ batch.token ? ubatch.token.data() : nullptr,
-        /*.embd         =*/ batch.embd ? ubatch.embd.data() : nullptr,
-        /*.pos          =*/ ubatch.pos.data(),
-        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
-        /*.seq_id       =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx      =*/ ubatch.seq_idx.data(),
-        /*.output       =*/ ubatch.output.data(),
+        /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
+
+        /*.token        =*/ batch.token ? udata->token.data() : nullptr,
+        /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.data         =*/ std::move(udata),
     };
 
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
+        LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);
 
         ubatch_print(res, debug);
     }
@@ -727,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
 void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s:   equal_seqs   = %d\n", __func__, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s:   equal_seqs   = %d\n", __func__, ubatch.equal_seqs());
         LLAMA_LOG_DEBUG("%s:   n_tokens     = %d\n", __func__, ubatch.n_tokens);
         LLAMA_LOG_DEBUG("%s:   n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
         LLAMA_LOG_DEBUG("%s:   n_seqs       = %d\n", __func__, ubatch.n_seqs);
index 3420803ff946967319e96947497b315ae74d1f6c..d563adc66aaf561eb037f9650cb7c07574cac3a7 100644 (file)
@@ -8,12 +8,17 @@
 #include <vector>
 #include <set>
 #include <bitset>
+#include <memory>
 #include <unordered_map>
 
 // keep this struct lightweight
-// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
-    bool equal_seqs;
+    bool equal_seqs() const {
+        return b_equal_seqs != 0;
+    }
+
+    uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
+                           //       otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
@@ -34,6 +39,20 @@ struct llama_ubatch {
     llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
     int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
     int8_t       *  output;     // [n_tokens]         | i   | -
+
+    struct data_t {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id>   seq_id_unq;
+        std::vector<int32_t>        seq_idx;
+        std::vector<int8_t>         output;
+    };
+
+    // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
+    std::shared_ptr<data_t> data;
 };
 
 // a helper for sanitizing, fulfilling and splitting a batch
@@ -48,6 +67,7 @@ public:
             const llama_vocab & vocab,
             const llama_memory_i * memory,
             uint32_t n_embd,
+            uint32_t n_seq_max,
             bool output_all);
 
     const llama_batch & get_batch() const;
@@ -100,6 +120,7 @@ private:
     const uint32_t n_pos_per_embd;
 
     uint32_t n_embd;
+    uint32_t n_seq_max;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
@@ -115,7 +136,7 @@ private:
     using seq_cpl_t = std::vector<bool>;
 
     // helper flag to quickly determine if there are any coupled sequences in the batch
-    bool has_cpl;
+    bool has_cpl = false;
 
     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
@@ -135,20 +156,5 @@ private:
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;
 
-    // llama_ubatch points to this data:
-    struct ubatch {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<llama_seq_id>   seq_id_unq;
-        std::vector<int32_t>        seq_idx;
-        std::vector<int8_t>         output;
-    };
-
-    // current splitting state:
-    std::vector<ubatch> ubatches;
-
     int debug;
 };
index cbc19d3c40c30e2cec9ca96b2631ce6aab4fdd84..d34bb26878c2a971cc978cc08592974ad14b4586 100644 (file)
@@ -56,6 +56,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "glmedge",           LLM_CHAT_TEMPLATE_GLMEDGE           },
     { "minicpm",           LLM_CHAT_TEMPLATE_MINICPM           },
     { "exaone3",           LLM_CHAT_TEMPLATE_EXAONE_3          },
+    { "exaone4",           LLM_CHAT_TEMPLATE_EXAONE_4          },
     { "rwkv-world",        LLM_CHAT_TEMPLATE_RWKV_WORLD        },
     { "granite",           LLM_CHAT_TEMPLATE_GRANITE           },
     { "gigachat",          LLM_CHAT_TEMPLATE_GIGACHAT          },
@@ -65,6 +66,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "llama4",            LLM_CHAT_TEMPLATE_LLAMA4            },
     { "smolvlm",           LLM_CHAT_TEMPLATE_SMOLVLM           },
     { "hunyuan-moe",       LLM_CHAT_TEMPLATE_HUNYUAN_MOE       },
+    { "kimi-k2",           LLM_CHAT_TEMPLATE_KIMI_K2           },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -167,10 +169,13 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
     } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
     } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
+        if (tmpl_contains("[|tool|]")) {
+            return LLM_CHAT_TEMPLATE_EXAONE_4;
+        }
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
         // EXAONE-3.0-7.8B-Instruct
         return LLM_CHAT_TEMPLATE_EXAONE_3;
-    } else if (tmpl_contains("rwkv-world")) {
+    } else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) {
         return LLM_CHAT_TEMPLATE_RWKV_WORLD;
     } else if (tmpl_contains("<|start_of_role|>")) {
         return LLM_CHAT_TEMPLATE_GRANITE;
@@ -188,6 +193,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_DOTS1;
     } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
         return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
+    } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
+        return LLM_CHAT_TEMPLATE_KIMI_K2;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -529,6 +536,22 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "[|assistant|]";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_4) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "user") {
+                ss << "[|user|]" << trim(message->content) << "\n";
+            } else if (role == "assistant") {
+                ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "tool") {
+                ss << "[|tool|]" << trim(message->content) << "[|endofturn|]\n";
+            }
+        }
+        if (add_ass) {
+            ss << "[|assistant|]";
+        }
     } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
         for (size_t i = 0; i < chat.size(); i++) {
@@ -680,6 +703,25 @@ int32_t llm_chat_apply_template(
                 ss << "<|startoftext|>" << message->content << "<|extra_0|>";
             }
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) {
+        // moonshotai/Kimi-K2-Instruct
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|im_system|>system<|im_middle|>";
+            } else if (role == "user") {
+                ss << "<|im_user|>user<|im_middle|>";
+            } else if (role == "assistant") {
+                ss << "<|im_assistant|>assistant<|im_middle|>";
+            } else if (role == "tool") {
+                ss << "<|im_system|>tool<|im_middle|>";
+            }
+
+            ss << message->content << "<|im_end|>";
+        }
+        if (add_ass) {
+            ss << "<|im_assistant|>assistant<|im_middle|>";
+        }
     } else {
         // template not supported
         return -1;
index b621fda281669897f2cf604ba6aba277a7722b6c..6968a19fbe13c8de2c2135df1d49ce039ce509eb 100644 (file)
@@ -35,6 +35,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_GLMEDGE,
     LLM_CHAT_TEMPLATE_MINICPM,
     LLM_CHAT_TEMPLATE_EXAONE_3,
+    LLM_CHAT_TEMPLATE_EXAONE_4,
     LLM_CHAT_TEMPLATE_RWKV_WORLD,
     LLM_CHAT_TEMPLATE_GRANITE,
     LLM_CHAT_TEMPLATE_GIGACHAT,
@@ -45,6 +46,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_DOTS1,
     LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
+    LLM_CHAT_TEMPLATE_KIMI_K2,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
index 06e93b19cbf4087b24284ea3a473f12441c31b6a..9e77fe6d869599255729b6ed0e908becf8be390d 100644 (file)
@@ -98,10 +98,20 @@ llama_context::llama_context(
         LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
         cparams.n_batch = GGML_KQ_MASK_PAD;
     }
-
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
     cparams.op_offload = params.op_offload;
+    cparams.kv_unified = params.kv_unified;
+
+    {
+        const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+        supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false;
+
+        if (!supports_set_rows && !cparams.kv_unified) {
+            LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
+            cparams.kv_unified = true;
+        }
+    }
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
@@ -112,6 +122,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: kv_unified    = %s\n",   __func__, cparams.kv_unified ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
 
@@ -227,8 +238,8 @@ llama_context::llama_context(
 
         LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
 
-        // buffer used to store the computation graph and the tensor meta data
-        buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+        gf_res_prev.reset(new llm_graph_result(max_nodes));
+        gf_res_reserve.reset(new llm_graph_result(max_nodes));
 
         // TODO: move these checks to ggml_backend_sched
         // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
@@ -267,7 +278,7 @@ llama_context::llama_context(
 
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -287,7 +298,7 @@ llama_context::llama_context(
 
         cross.v_embd.clear();
 
-        // reserve pp graph first so that buffers are only allocated once
+        // reserve pp (prompt processing) graph first so that buffers are only allocated once
         {
             auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
@@ -298,9 +309,9 @@ llama_context::llama_context(
             n_nodes_pp  = ggml_graph_n_nodes(gf);
         }
 
-        // reserve with tg graph to get the number of splits and nodes
+        // reserve with tg (token generation) graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1, mctx.get());
+            auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -311,6 +322,10 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
+            // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
+            //
+            // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
+            //
             auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
@@ -388,10 +403,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
     return sched.get();
 }
 
-ggml_context * llama_context::get_ctx_compute() const {
-    return ctx_compute.get();
-}
-
 uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
@@ -463,6 +474,11 @@ bool llama_context::kv_self_update(bool optimize) {
                 }
         }
 
+        // reset the previous graph result to make sure that it won't be reused
+        // TODO: change the mctx->apply() to return information if a graph reserve is needed
+        //       reset the graph result only if the memory module did reset the scheduler
+        gf_res_prev->reset();
+
         if (!mctx->apply()) {
             LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
         }
@@ -475,7 +491,7 @@ bool llama_context::kv_self_update(bool optimize) {
             throw std::runtime_error("failed to initialize memory context");
         }
 
-        const uint32_t n_seqs   = cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -492,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }
 
 float * llama_context::get_logits() {
+    output_reorder();
+
     return logits;
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (logits == nullptr) {
             throw std::runtime_error("no logits");
@@ -534,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) {
 }
 
 float * llama_context::get_embeddings() {
+    output_reorder();
+
     return embd;
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (embd == nullptr) {
             throw std::runtime_error("no embeddings");
@@ -678,38 +702,59 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }
 
-    auto * gf = graph_init();
-    if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
+    auto * res = gf_res_prev.get();
+    auto * gf  = res->get_gf();
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
+    // the new graph parameters
+    // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
+    const auto gparams = graph_params(res, ubatch, mctx, gtype);
 
-    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+    if (res->can_reuse(gparams)) {
+        //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
 
-    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
-        ret = GGML_STATUS_ALLOC_FAILED;
-        return nullptr;
+        n_reused++;
+    } else {
+        res->reset();
+
+        ggml_backend_sched_reset(sched.get());
+        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+        //const auto t_start_us = ggml_time_us();
+
+        gf = model.build_graph(gparams);
+
+        //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+
+        if (!gf) {
+            LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+            ret = GGML_STATUS_FAILED;
+            return nullptr;
+        }
+
+        if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+            ret = GGML_STATUS_ALLOC_FAILED;
+            return nullptr;
+        }
     }
 
-    res->set_inputs(&ubatch);
+    // set the input data for the input tensors
+    {
+        //const auto t_start_us = ggml_time_us();
+
+        res->set_inputs(&ubatch);
+
+        //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+    }
 
-    const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+    const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
         ret = status;
@@ -731,16 +776,19 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     const auto & hparams = model.hparams;
 
-    const int64_t n_embd = hparams.n_embd;
+    const int64_t n_embd  = hparams.n_embd;
+    const int32_t n_vocab = model.vocab.n_tokens();
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
+    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
     const uint32_t n_tokens = balloc->get_n_tokens();
 
+    // [TAG_NO_CACHE_PAD]
+    // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
     const llama_ubatch ubatch = balloc->split_simple(n_tokens);
 
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
@@ -767,9 +815,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     n_outputs = n_tokens;
 
-    ggml_backend_sched_reset(sched.get());
-    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
     const auto causal_attn_org = cparams.causal_attn;
 
     // always use non-causal attention for encoder graphs
@@ -778,7 +823,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
+    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
     cparams.causal_attn = causal_attn_org;
 
@@ -791,10 +836,20 @@ int llama_context::encode(const llama_batch & batch_inp) {
         }
     }
 
+    auto * t_logits = res->get_logits();
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
 
+    // extract logits
+   if (logits && t_logits) {
+        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+        GGML_ASSERT(backend_res != nullptr);
+        GGML_ASSERT(logits != nullptr);
+
+        ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
+    }
+
     // extract embeddings
-    if (t_embd) {
+    if (embd && t_embd) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
         GGML_ASSERT(backend_embd != nullptr);
 
@@ -844,9 +899,11 @@ int llama_context::encode(const llama_batch & batch_inp) {
         }
     }
 
-    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-    // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
+    if (!supports_set_rows) {
+        // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+        // overlap with device computation.
+        ggml_backend_sched_reset(sched.get());
+    }
 
     // TODO: hacky solution
     if (model.arch == LLM_ARCH_T5 && t_embd) {
@@ -899,7 +956,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;
 
-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -927,6 +984,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
+    output_swaps.clear();
 
     bool did_optimize = false;
 
@@ -1005,11 +1063,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
             n_outputs = n_outputs_new;
         }
 
-        ggml_backend_sched_reset(sched.get());
-        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1149,9 +1204,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
         // make the outputs have the same order they had in the user-provided batch
         // note: this is mostly relevant for recurrent models atm
         if (!sorted_output) {
-            const uint32_t n_vocab = model.vocab.n_tokens();
-            const uint64_t n_embd  = model.hparams.n_embd;
-
             GGML_ASSERT((size_t) n_outputs == out_ids.size());
 
             // TODO: is there something more efficient which also minimizes swaps?
@@ -1167,16 +1219,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     continue;
                 }
                 std::swap(out_ids[i], out_ids[j_min]);
-                if (logits_size > 0) {
-                    for (uint32_t k = 0; k < n_vocab; k++) {
-                        std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-                    }
-                }
-                if (embd_size > 0) {
-                    for (uint32_t k = 0; k < n_embd; k++) {
-                        std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-                    }
-                }
+
+                // remember the swaps and apply them lazily upon logits/embeddings access
+                output_swaps.push_back({ i, j_min });
             }
 
             std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1190,9 +1235,11 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();
 
-    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-    // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
+    if (!supports_set_rows) {
+        // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+        // overlap with device computation.
+        ggml_backend_sched_reset(sched.get());
+    }
 
     return 0;
 }
@@ -1271,24 +1318,40 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     return n_outputs_max;
 }
 
+void llama_context::output_reorder() {
+    const uint32_t n_vocab = model.vocab.n_tokens();
+    const uint64_t n_embd  = model.hparams.n_embd;
+
+    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
+        const uint32_t i0 = output_swaps[s].i0;
+        const uint32_t i1 = output_swaps[s].i1;
+
+        if (logits_size > 0) {
+            for (uint32_t k = 0; k < n_vocab; k++) {
+                std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
+            }
+        }
+
+        if (embd_size > 0) {
+            for (uint32_t k = 0; k < n_embd; k++) {
+                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
+            }
+        }
+    }
+
+    output_swaps.clear();
+}
+
 //
 // graph
 //
 
-int32_t llama_context::graph_max_nodes() const {
-    return std::max<int32_t>(65536, 5*model.n_tensors());
+uint32_t llama_context::graph_max_nodes() const {
+    return std::max<uint32_t>(1024u, 8u*model.n_tensors());
 }
 
-ggml_cgraph * llama_context::graph_init() {
-    ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute_meta.size(),
-        /*.mem_buffer =*/ buf_compute_meta.data(),
-        /*.no_alloc   =*/ true,
-    };
-
-    ctx_compute.reset(ggml_init(params));
-
-    return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
+llm_graph_result * llama_context::get_gf_res_reserve() const {
+    return static_cast<llm_graph_result *>(gf_res_reserve.get());
 }
 
 ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
@@ -1301,6 +1364,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
         LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
     }
 
+    ggml_backend_sched_reset(sched.get());
+
+    // when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that
+    gf_res_prev->reset();
+
     // store the n_outputs as it is, and restore it afterwards
     // TODO: not sure if needed, might simplify in the future by removing this
     const auto save_n_outputs = this->n_outputs;
@@ -1310,17 +1378,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
-    auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
+    auto * res = gf_res_reserve.get();
 
-    this->n_outputs = save_n_outputs;
+    const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
 
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
-        return nullptr;
-    }
+    res->reset();
 
-    ggml_backend_sched_reset(sched.get());
+    auto * gf = model.build_graph(gparams);
+
+    this->n_outputs = save_n_outputs;
 
     // initialize scheduler with the specified graph
     if (!ggml_backend_sched_reserve(sched.get(), gf)) {
@@ -1331,28 +1397,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     return gf;
 }
 
-llm_graph_result_ptr llama_context::graph_build(
-                      ggml_context * ctx,
-                       ggml_cgraph * gf,
-                const llama_ubatch & ubatch,
-                    llm_graph_type   gtype,
-      const llama_memory_context_i * mctx) {
-    return model.build_graph(
-            {
-                /*.ctx         =*/ ctx,
-                /*.arch        =*/ model.arch,
-                /*.hparams     =*/ model.hparams,
-                /*.cparams     =*/ cparams,
-                /*.ubatch      =*/ ubatch,
-                /*.sched       =*/ sched.get(),
-                /*.backend_cpu =*/ backend_cpu,
-                /*.cvec        =*/ &cvec,
-                /*.loras       =*/ &loras,
-                /*.mctx        =*/ mctx,
-                /*.cross       =*/ &cross,
-                /*.n_outputs   =*/ n_outputs,
-                /*.cb          =*/ graph_get_cb(),
-            }, gf, gtype);
+llm_graph_params llama_context::graph_params(
+                        llm_graph_result * res,
+                      const llama_ubatch & ubatch,
+            const llama_memory_context_i * mctx,
+            llm_graph_type   gtype) const {
+    return {
+        /*.arch        =*/ model.arch,
+        /*.hparams     =*/ model.hparams,
+        /*.cparams     =*/ cparams,
+        /*.ubatch      =*/ ubatch,
+        /*.gtype       =*/ gtype,
+        /*.sched       =*/ sched.get(),
+        /*.backend_cpu =*/ backend_cpu,
+        /*.cvec        =*/ &cvec,
+        /*.loras       =*/ &loras,
+        /*.mctx        =*/ mctx,
+        /*.cross       =*/ &cross,
+        /*.n_outputs   =*/ n_outputs,
+        /*.cb          =*/ graph_get_cb(),
+        /*.res         =*/ res,
+    };
 }
 
 ggml_status llama_context::graph_compute(
@@ -1930,6 +1995,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
     data.t_eval_ms   = 1e-3 * t_eval_us;
     data.n_p_eval    = std::max(1, n_p_eval);
     data.n_eval      = std::max(1, n_eval);
+    data.n_reused    = std::max(0, n_reused);
 
     return data;
 }
@@ -1938,6 +2004,7 @@ void llama_context::perf_reset() {
     t_start_us  = ggml_time_us();
     t_eval_us   = n_eval = 0;
     t_p_eval_us = n_p_eval = 0;
+    n_reused    = 0;
 }
 
 //
@@ -2028,7 +2095,7 @@ void llama_context::opt_epoch_iter(
             batch.logits  [pos_batch]    = true;
         }
 
-        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
             LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
             return;
         }
@@ -2064,8 +2131,13 @@ void llama_context::opt_epoch_iter(
                 break;
             }
 
-            auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
+            auto * res = gf_res_prev.get();
+
+            const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
+
+            res->reset();
+
+            auto * gf = model.build_graph(gparams);
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2187,6 +2259,7 @@ llama_context_params llama_context_default_params() {
         /*.no_perf                     =*/ true,
         /*.op_offload                  =*/ true,
         /*.swa_full                    =*/ true,
+        /*.kv_unified                  =*/ false,
     };
 
     return result;
@@ -2807,6 +2880,7 @@ void llama_perf_context_print(const llama_context * ctx) {
     LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
             __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
     LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+    LLAMA_LOG_INFO("%s:    graphs reused = %10d\n", __func__, data.n_reused);
 }
 
 void llama_perf_context_reset(llama_context * ctx) {
index 9ce05715a8c0306312bd03f516ae38f1f79cb2f6..5c3a1c09886ea29178b9427f43abc3c085a7e9f7 100644 (file)
@@ -35,8 +35,6 @@ struct llama_context {
 
     ggml_backend_sched_t get_sched() const;
 
-    ggml_context * get_ctx_compute() const;
-
     uint32_t n_ctx()         const;
     uint32_t n_ctx_per_seq() const;
     uint32_t n_batch()       const;
@@ -96,7 +94,7 @@ struct llama_context {
     // if memory_context is provided, it will be applied first to the context's memory
     // ret contains the status of the graph computation
     // returns nullptr only if ret != GGML_STATUS_SUCCESS
-    llm_graph_result_ptr process_ubatch(
+    llm_graph_result * process_ubatch(
                 const llama_ubatch & ubatch,
                     llm_graph_type   gtype,
             llama_memory_context_i * mctx,
@@ -183,15 +181,17 @@ private:
     // Returns max number of outputs for which space was reserved.
     uint32_t output_reserve(int32_t n_outputs);
 
+    void output_reorder();
+
     //
     // graph
     //
 
 public:
-    int32_t graph_max_nodes() const;
+    uint32_t graph_max_nodes() const;
 
-    // zero-out inputs and create the ctx_compute for the compute graph
-    ggml_cgraph * graph_init();
+    // can reuse the llm_graph_result instance of the context (for example to update a memory module)
+    llm_graph_result * get_gf_res_reserve() const;
 
     // returns the result of ggml_backend_sched_graph_compute_async execution
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
@@ -200,12 +200,11 @@ public:
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
 private:
-    llm_graph_result_ptr graph_build(
-                      ggml_context * ctx,
-                       ggml_cgraph * gf,
-                const llama_ubatch & ubatch,
-                    llm_graph_type   gtype,
-      const llama_memory_context_i * mctx);
+    llm_graph_params graph_params(
+                        llm_graph_result * res,
+                      const llama_ubatch & ubatch,
+            const llama_memory_context_i * mctx,
+                          llm_graph_type   gtype) const;
 
     llm_graph_cb graph_get_cb() const;
 
@@ -253,13 +252,18 @@ private:
 
     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
 
+    struct swap_info {
+        uint32_t i0;
+        uint32_t i1;
+    };
+
+    std::vector<swap_info> output_swaps;
+
     ggml_backend_sched_ptr sched;
 
     ggml_backend_t backend_cpu = nullptr;
     std::vector<ggml_backend_ptr> backends;
 
-    ggml_context_ptr ctx_compute;
-
     // training
     ggml_opt_context_t opt_ctx = nullptr;
 
@@ -275,14 +279,18 @@ private:
     std::vector<ggml_backend_t>             backend_ptrs;
     std::vector<ggml_backend_buffer_type_t> backend_buft;
 
-    // memory buffers used to evaluate the model
-    std::vector<uint8_t> buf_compute_meta;
+    llm_graph_result_ptr gf_res_prev;
+    llm_graph_result_ptr gf_res_reserve;
 
     // host buffer for the model output (logits and embeddings)
     ggml_backend_buffer_ptr buf_output;
 
     bool has_evaluated_once = false;
 
+    // env: LLAMA_SET_ROWS (temporary)
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
+    bool supports_set_rows = false;
+
     // perf
     mutable int64_t t_start_us  = 0;
     mutable int64_t t_load_us   = 0;
@@ -294,4 +302,6 @@ private:
 
     mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
     mutable int32_t n_eval   = 0; // number of eval calls
+
+    mutable int32_t n_reused = 0; // number of times the previous graph was reused
 };
index 118615d5bd2d59f4e640c661592cf91ee0d3fda3..38750affc500b74504f53bf65320332ddcef0d09 100644 (file)
@@ -11,8 +11,8 @@ struct llama_cparams {
     uint32_t n_batch;
     uint32_t n_ubatch;
     uint32_t n_seq_max;
-    int      n_threads;       // number of threads to use for generation
-    int      n_threads_batch; // number of threads to use for batch processing
+    int32_t  n_threads;       // number of threads to use for generation
+    int32_t  n_threads_batch; // number of threads to use for batch processing
 
     float rope_freq_base;
     float rope_freq_scale;
@@ -33,6 +33,7 @@ struct llama_cparams {
     bool no_perf;
     bool warmup;
     bool op_offload;
+    bool kv_unified;
 
     enum llama_pooling_type pooling_type;
 
index a248a7ec22350898ac2f31ba9d817e528651b925..b63a41053b488b23025b73df495b1f2f2e8c43d3 100644 (file)
@@ -28,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
+    res &= (!embd   && !params.ubatch.embd)  || (embd   &&   embd->ne[0] == params.ubatch.n_tokens);
+
+    return res;
+}
+
 void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && pos) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -50,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= pos->ne[0] == params.ubatch.n_tokens;
+
+    return res;
+}
+
 void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && attn_scale) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -71,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
         const int64_t n_tokens = ubatch->n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
-        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+        GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
 
         int32_t * data = (int32_t *) pos_bucket->data;
 
@@ -118,6 +135,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= n_outputs == params.n_outputs;
+
+    return res;
+}
+
 void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens     = ubatch->n_tokens;
@@ -287,6 +312,24 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
+bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+  //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= self_kq_mask->ne[0] == mctx->get_n_kv();
+    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= mctx->get_supports_set_rows(); // TODO: tmp
+
+    return res;
+}
+
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
@@ -299,6 +342,30 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
     mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
+bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+  //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
+  //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
+    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
+    res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
+
+    return res;
+}
+
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     GGML_ASSERT(cross_kq_mask);
 
@@ -306,7 +373,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     const int64_t n_tokens = ubatch->n_tokens;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
-    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
 
     float * data = (float *) cross_kq_mask->data;
 
@@ -340,6 +407,91 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     inp_rs->set_input(ubatch);
 }
 
+//
+// llm_graph_result
+//
+
+llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
+    reset();
+
+    const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
+    debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
+}
+
+int64_t llm_graph_result::get_max_nodes() const {
+    return max_nodes;
+}
+
+void llm_graph_result::reset() {
+    t_tokens      = nullptr;
+    t_logits      = nullptr;
+    t_embd        = nullptr;
+    t_embd_pooled = nullptr;
+
+    params = {};
+
+    inputs.clear();
+
+    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+    ggml_init_params params = {
+        /*.mem_size   =*/ buf_compute_meta.size(),
+        /*.mem_buffer =*/ buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    ctx_compute.reset(ggml_init(params));
+
+    gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
+}
+
+void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
+    for (auto & input : inputs) {
+        input->set_input(ubatch);
+    }
+}
+
+bool llm_graph_result::can_reuse(const llm_graph_params & params) {
+    if (!this->params.allow_reuse(params)) {
+        if (debug > 1) {
+            LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
+        }
+
+        return false;
+    }
+
+    if (debug > 1) {
+        LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
+    }
+
+    bool res = true;
+
+    for (auto & input : inputs) {
+        const bool cur = input->can_reuse(params);
+
+        if (debug > 1) {
+            LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
+        }
+
+        res = res && cur;
+    }
+
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
+    }
+
+    return res;
+}
+
+llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
+    inputs.emplace_back(std::move(input));
+    return inputs.back().get();
+}
+
+void llm_graph_result::set_params(const llm_graph_params & params) {
+    this->params = params;
+}
+
 //
 // llm_graph_context
 //
@@ -374,7 +526,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_ctx_orig       (cparams.n_ctx_orig_yarn),
     pooling_type     (cparams.pooling_type),
     rope_type        (hparams.rope_type),
-    ctx0             (params.ctx),
     sched            (params.sched),
     backend_cpu      (params.backend_cpu),
     cvec             (params.cvec),
@@ -382,7 +533,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     mctx             (params.mctx),
     cross            (params.cross),
     cb_func          (params.cb),
-    res              (std::make_unique<llm_graph_result>()) {
+    res              (params.res),
+    ctx0             (res->get_ctx()),
+    gf               (res->get_gf()) {
+        res->set_params(params);
     }
 
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
@@ -753,20 +907,28 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(cur, "ffn_moe_weighted", il);
     }
 
+    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
+
+    assert(n_expert_used > 0);
+
+    // order the views before the adds
+    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
+        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
+
+        ggml_build_forward_expand(gf, cur_experts[i]);
+    }
+
     // aggregate experts
-    ggml_tensor * moe_out = nullptr;
-    for (int i = 0; i < n_expert_used; ++i) {
-        ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
-                experts->nb[2], i*experts->nb[1]);
+    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
+    //       to avoid potentially a large number of add nodes during warmup
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/14753
+    ggml_tensor * moe_out = cur_experts[0];
 
-        if (i == 0) {
-            moe_out = cur_expert;
-        } else {
-            moe_out = ggml_add(ctx0, moe_out, cur_expert);
-        }
+    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
+        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
     }
 
-    if (n_expert_used == 1) {
+    if (hparams.n_expert_used == 1) {
         // avoid returning a non-contiguous tensor
         moe_out = ggml_cont(ctx0, moe_out);
     }
@@ -972,7 +1134,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }
 
 ggml_tensor * llm_graph_context::build_attn_mha(
-         ggml_cgraph * gf,
          ggml_tensor * q,
          ggml_tensor * k,
          ggml_tensor * v,
@@ -982,13 +1143,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
              float     kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
+    // split the batch into streams if needed
+    const auto n_stream = k->ne[3];
+
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
+
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
-    const auto n_tokens = q->ne[1];
-    const auto n_head   = q->ne[2];
-    const auto n_kv     = k->ne[1];
+    const auto n_kv = k->ne[1];
 
     ggml_tensor * cur;
 
@@ -1030,7 +1194,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }
 
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1075,7 +1239,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        // recombine streams
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
@@ -1102,7 +1267,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_no_cache * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1122,11 +1286,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
+    // [TAG_NO_CACHE_PAD]
+    // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
+    assert(!ubatch.equal_seqs());
+
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1156,13 +1324,14 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_kv     = mctx_cur->get_n_kv();
         const auto n_tokens = ubatch.n_tokens;
+        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1181,7 +1350,6 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1214,7 +1382,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1234,7 +1402,6 @@ ggml_tensor * llm_graph_context::build_attn(
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified_iswa * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1281,7 +1448,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1314,7 +1481,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
 
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_cross * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1336,7 +1502,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1362,13 +1528,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1382,7 +1550,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -1392,7 +1560,6 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 }
 
 ggml_tensor * llm_graph_context::build_rs(
-        ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
             int32_t   state_size,
@@ -1450,21 +1617,19 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
 
 ggml_tensor * llm_graph_context::build_rs(
         llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
         ggml_tensor * s,
             int32_t   state_size,
             int32_t   n_seqs,
         const llm_graph_get_rows_fn & get_state_rows) const {
     const auto * kv_state = inp->mctx;
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+    return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
     llm_graph_input_rs * inp,
-           ggml_cgraph * gf,
     const llama_ubatch & ubatch,
-                 int   il) const {
+                   int   il) const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
@@ -1474,7 +1639,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
     ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
     ggml_tensor * token_shift = build_rs(
-            inp, gf, token_shift_all,
+            inp, token_shift_all,
             hparams.n_embd_r(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
@@ -1514,7 +1679,6 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 }
 
 void llm_graph_context::build_pooling(
-        ggml_cgraph * gf,
         ggml_tensor * cls,
         ggml_tensor * cls_b,
         ggml_tensor * cls_out,
index fbf8e2889564ddb591af4bad939857305737f36a..a28a8c4bddad838514e0fe301f982ab0d4b9890c 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "llama-arch.h"
+#include "llama-batch.h"
 #include "llama-hparams.h"
 #include "llama-adapter.h"
 
@@ -14,7 +15,6 @@ struct ggml_cgraph;
 struct ggml_context;
 struct ggml_tensor;
 
-struct llama_ubatch;
 struct llama_cparams;
 
 struct llama_memory_context_i;
@@ -69,6 +69,8 @@ struct llama_cross {
     std::vector<std::set<llama_seq_id>> seq_ids_enc;
 };
 
+struct llm_graph_params;
+
 //
 // llm_graph_input
 //
@@ -78,11 +80,19 @@ public:
     virtual ~llm_graph_input_i() = default;
 
     virtual void set_input(const llama_ubatch * ubatch) = 0;
+
+    // return true if the resulting input tensors using the provided graph parameters would be
+    //   the same as the previous input tensors that we have currently stored in the object
+    virtual bool can_reuse(const llm_graph_params & params) {
+        // returning false here by default will prevent from reusing the graph if the check
+        //   for the input type has not been implemented yet
+        GGML_UNUSED(params);
+        return false;
+    }
 };
 
 using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
 
-
 class llm_graph_input_embd : public llm_graph_input_i {
 public:
     llm_graph_input_embd()          = default;
@@ -90,6 +100,8 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * tokens = nullptr; // I32 [n_batch]
     ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]
 };
@@ -101,6 +113,8 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * pos = nullptr; // I32 [n_batch]
 
     const uint32_t n_pos_per_embd = 1;
@@ -154,17 +168,19 @@ public:
     llm_graph_input_out_ids(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
+            uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
     virtual ~llm_graph_input_out_ids() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * out_ids; // I32 [n_outputs]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const int32_t n_outputs;
+    const uint32_t n_outputs;
 };
 
 class llm_graph_input_mean : public llm_graph_input_i {
@@ -249,16 +265,18 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * get_k_idxs() const { return self_k_idxs; }
     ggml_tensor * get_v_idxs() const { return self_v_idxs; }
 
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
     ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -280,6 +298,8 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * get_k_idxs()     const { return self_k_idxs; }
     ggml_tensor * get_v_idxs()     const { return self_v_idxs; }
     ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
@@ -289,14 +309,14 @@ public:
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
     ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
     ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
-    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch, 1, 1]
-    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -351,40 +371,108 @@ public:
 // along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
 //   these are used by the llama_context to extact the relevant data, based on the compute parameters
 
-class llm_graph_result_i {
-public:
-    virtual ~llm_graph_result_i() = default;
+// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
+using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
 
-    virtual ggml_tensor * get_tokens()      = 0;
-    virtual ggml_tensor * get_logits()      = 0;
-    virtual ggml_tensor * get_embd()        = 0;
-    virtual ggml_tensor * get_embd_pooled() = 0;
+class llm_graph_result;
 
-    virtual void set_inputs(const llama_ubatch * ubatch) = 0;
-};
+struct llm_graph_params {
+    llm_arch arch = LLM_ARCH_UNKNOWN;
 
-using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
+    llama_hparams hparams;
+    llama_cparams cparams;
 
+    llama_ubatch ubatch; // note: intentionally make a copy
 
-class llm_graph_result : public llm_graph_result_i {
-public:
-    virtual ~llm_graph_result() = default;
+    llm_graph_type gtype;
 
-    ggml_tensor * get_tokens()      override { return t_tokens; }
-    ggml_tensor * get_logits()      override { return t_logits; }
-    ggml_tensor * get_embd()        override { return t_embd; }
-    ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
+    ggml_backend_sched_t sched;
+    ggml_backend_t backend_cpu;
 
-    void set_inputs(const llama_ubatch * ubatch) override {
-        for (auto & input : inputs) {
-            input->set_input(ubatch);
+    const llama_adapter_cvec     * cvec;
+    const llama_adapter_loras    * loras;
+    const llama_memory_context_i * mctx;
+    const llama_cross            * cross;
+
+    uint32_t n_outputs;
+
+    llm_graph_cb cb;
+
+    llm_graph_result * res;
+
+    // return true if the "other" params would result in a graph with the same topology as with the current params
+    //   having the same topology allows us to reuse the graph in some cases
+    bool allow_reuse(const llm_graph_params & other) const {
+        // first check the ubatch
+        bool can_reuse_ubatch =
+            ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
+            ubatch.n_tokens     == other.ubatch.n_tokens &&
+            ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
+            ubatch.n_seqs       == other.ubatch.n_seqs &&
+            ubatch.n_seqs_unq   == other.ubatch.n_seqs_unq &&
+            (
+                (!ubatch.token && !other.ubatch.token) ||
+                (!ubatch.embd  && !other.ubatch.embd)
+            );
+
+        if (can_reuse_ubatch && !ubatch.equal_seqs()) {
+            if (!ubatch.data) {
+                // if the old ubatch does not own it's data, then we cannot guarantee that it is still alive, and
+                //   therefore we cannot perform the sequence id check. normally should never happen
+                can_reuse_ubatch = false;
+            } else {
+                for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                    can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
+                }
+            }
         }
-    }
 
-    llm_graph_input_i * add_input(llm_graph_input_ptr input) {
-        inputs.emplace_back(std::move(input));
-        return inputs.back().get();
+        if (!can_reuse_ubatch) {
+            return false;
+        }
+
+        return
+            cparams.embeddings  == other.cparams.embeddings  &&
+            cparams.causal_attn == other.cparams.causal_attn &&
+            arch      == other.arch  &&
+            gtype     == other.gtype &&
+            cvec      == other.cvec  &&
+            loras     == other.loras &&
+            cross     == other.cross &&
+            n_outputs == other.n_outputs;
     }
+};
+
+class llm_graph_result {
+public:
+    llm_graph_result(int64_t max_nodes);
+
+    virtual ~llm_graph_result() = default;
+
+    ggml_tensor * get_tokens()      const { return t_tokens; }
+    ggml_tensor * get_logits()      const { return t_logits; }
+    ggml_tensor * get_embd()        const { return t_embd; }
+    ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
+
+    ggml_cgraph  * get_gf()  const { return gf; }
+    ggml_context * get_ctx() const { return ctx_compute.get(); }
+
+    int64_t get_max_nodes() const;
+
+    void reset();
+
+    void set_inputs(const llama_ubatch * ubatch);
+
+    // try to update the existing graph result using the new graph parameters in order to reuse it
+    // this can only be done if we determine that the resulting graph using the new graph parameters
+    //   would be identical to the existing graph. in that case, we simply have to update the memory
+    //   contexts of the input tensors of the graph and we can reuse it for another computation
+    // return true if the graph was updated and can be reused
+    bool can_reuse(const llm_graph_params & params);
+
+    llm_graph_input_i * add_input(llm_graph_input_ptr input);
+
+    void set_params(const llm_graph_params & params);
 
     // important graph nodes
     ggml_tensor * t_tokens      = nullptr;
@@ -393,36 +481,31 @@ public:
     ggml_tensor * t_embd_pooled = nullptr;
 
     std::vector<llm_graph_input_ptr> inputs;
-};
 
-//
-// llm_graph_context
-//
+    ggml_context_ptr ctx_compute;
 
-// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
-using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
+    // memory buffers used to evaluate the model
+    std::vector<uint8_t> buf_compute_meta;
 
-struct llm_graph_params {
-    ggml_context * ctx;
+    ggml_cgraph * gf;
 
-    const llm_arch arch;
+    int64_t max_nodes;
 
-    const llama_hparams & hparams;
-    const llama_cparams & cparams;
-    const llama_ubatch  & ubatch;
+private:
+    // keep a copy of the previous graph parameters
+    // we will use this to determine whether the graph can be reused by comparing them with the new parameters
+    // note: these are updated after constructing the new graph
+    llm_graph_params params;
 
-    ggml_backend_sched_t sched;
-    ggml_backend_t backend_cpu;
-
-    const llama_adapter_cvec     * cvec;
-    const llama_adapter_loras    * loras;
-    const llama_memory_context_i * mctx;
-    const llama_cross            * cross;
+    // env: LLAMA_GRAPH_RESULT_DEBUG
+    int debug = 0;
+};
 
-    uint32_t n_outputs;
+using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
 
-    const llm_graph_cb & cb;
-};
+//
+// llm_graph_context
+//
 
 // used in build_rs to properly order writes and avoid unnecessary copies
 using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
@@ -463,8 +546,6 @@ struct llm_graph_context {
     const enum llama_pooling_type pooling_type;
     const enum llama_rope_type    rope_type;
 
-    ggml_context * ctx0 = nullptr;
-
     ggml_backend_sched_t sched;
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
@@ -476,7 +557,10 @@ struct llm_graph_context {
 
     const llm_graph_cb & cb_func;
 
-    std::unique_ptr<llm_graph_result> res;
+    llm_graph_result * res;
+
+    ggml_context * ctx0 = nullptr;
+    ggml_cgraph  * gf   = nullptr;
 
     llm_graph_context(const llm_graph_params & params);
     virtual ~llm_graph_context() = default;
@@ -562,7 +646,6 @@ struct llm_graph_context {
     //
 
     ggml_tensor * build_attn_mha(
-             ggml_cgraph * gf,
              ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
              ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
              ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
@@ -575,7 +658,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_no_cache * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -590,7 +672,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -606,7 +687,6 @@ struct llm_graph_context {
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified_iswa * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -621,7 +701,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_cross * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -643,7 +722,6 @@ struct llm_graph_context {
     //         implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
     //         `llama_memory_recurrent`
     ggml_tensor * build_rs(
-            ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
                 int32_t   state_size,
@@ -658,7 +736,6 @@ struct llm_graph_context {
 
     ggml_tensor * build_rs(
             llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
             ggml_tensor * s,
                 int32_t   state_size,
                 int32_t   n_seqs,
@@ -666,9 +743,8 @@ struct llm_graph_context {
 
     ggml_tensor * build_rwkv_token_shift_load(
         llm_graph_input_rs * inp,
-               ggml_cgraph * gf,
         const llama_ubatch & ubatch,
-                     int   il) const;
+                       int   il) const;
 
     ggml_tensor * build_rwkv_token_shift_store(
              ggml_tensor * token_shift,
@@ -685,7 +761,6 @@ struct llm_graph_context {
     //
 
     void build_pooling(
-            ggml_cgraph * gf,
             ggml_tensor * cls,
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,
index 7aa736e2f39db9ca876dde6a791daff8bd1e5bea..c6c67d26f9392cd4c81e50160e52690ec96b7f10 100644 (file)
@@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
+bool llama_hparams::is_n_embd_k_gqa_variable() const {
+    const uint32_t val = n_embd_k_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (val != n_embd_k_gqa(il)) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+bool llama_hparams::is_n_embd_v_gqa_variable() const {
+    const uint32_t val = n_embd_v_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (val != n_embd_v_gqa(il)) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+uint32_t llama_hparams::n_embd_k_gqa_max() const {
+    uint32_t val = n_embd_k_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        val = std::max(val, n_embd_k_gqa(il));
+    }
+
+    return val;
+}
+
+uint32_t llama_hparams::n_embd_v_gqa_max() const {
+    uint32_t val = n_embd_v_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        val = std::max(val, n_embd_v_gqa(il));
+    }
+
+    return val;
+}
+
 uint32_t llama_hparams::n_embd_r() const {
     if (wkv_head_size != 0) {
         // for RWKV models
index d0500e4d0fd7732a0b1f51766cecd48c2c25c67e..ec7fd6a42bf54d2006da127e375b200046c851fb 100644 (file)
@@ -6,7 +6,7 @@
 
 // bump if necessary
 #define LLAMA_MAX_LAYERS  512
-#define LLAMA_MAX_EXPERTS 256  // DeepSeekV3
+#define LLAMA_MAX_EXPERTS 384  // Kimi-K2
 
 enum llama_expert_gating_func_type {
     LLAMA_EXPERT_GATING_FUNC_TYPE_NONE    = 0,
@@ -98,7 +98,7 @@ struct llama_hparams {
     float    rope_freq_scale_train;
     float    rope_freq_scale_train_swa;
     uint32_t n_ctx_orig_yarn;
-    float    rope_yarn_log_mul;
+    float    rope_yarn_log_mul = 0.0f;
 
     std::array<int, 4> rope_sections;
 
@@ -191,6 +191,14 @@ struct llama_hparams {
     // dimension of value embeddings across all k-v heads
     uint32_t n_embd_v_gqa(uint32_t il = 0) const;
 
+    // true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
+    bool is_n_embd_k_gqa_variable() const;
+    bool is_n_embd_v_gqa_variable() const;
+
+    // return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
+    uint32_t n_embd_k_gqa_max() const;
+    uint32_t n_embd_v_gqa_max() const;
+
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
     uint32_t n_embd_r() const;
index fe207ad536032d34930bc6f4bbfc4a35b4226bb5..01d27fb4db9b1d8adb104432b8c5c64f3b2ece7c 100644 (file)
@@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
                      bool   v_trans,
                      bool   offload,
                      bool   swa_full,
+                     bool   unified,
                  uint32_t   kv_size,
                  uint32_t   n_seq_max,
                  uint32_t   n_ubatch,
-                 uint32_t   n_pad) : hparams(model.hparams) {
+                 uint32_t   n_pad) : hparams(model.hparams), unified(unified) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     kv_base = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, size_base, n_seq_max, n_pad,
+            v_trans, offload, unified, size_base, n_seq_max, n_pad,
             0, LLAMA_SWA_TYPE_NONE);
 
     LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, size_swa, n_seq_max, n_pad,
+            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
             hparams.n_swa, hparams.swa_type);
 }
 
@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
     // first try simple split
     do {
+        if (!unified) {
+            // requires equal splits, so we skip the simple split
+            break;
+        }
+
         balloc.split_reset();
 
         std::vector<llama_ubatch> ubatches;
@@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         std::vector<llama_ubatch> ubatches;
         while (true) {
-            auto ubatch = balloc.split_equal(n_ubatch, false);
+            auto ubatch = balloc.split_equal(n_ubatch, !unified);
 
             if (ubatch.n_tokens == 0) {
                 break;
index 23205d826b23b231bb2026f082debe064f092500..d2650dadd3595b2614551e36941bffda8f6909af 100644 (file)
@@ -20,6 +20,7 @@ public:
                          bool   v_trans,
                          bool   offload,
                          bool   swa_full,
+                         bool   unified,
                      uint32_t   kv_size,
                      uint32_t   n_seq_max,
                      uint32_t   n_ubatch,
@@ -68,6 +69,8 @@ public:
 private:
     const llama_hparams & hparams;
 
+    const bool unified;
+
     std::unique_ptr<llama_kv_cache_unified> kv_base;
     std::unique_ptr<llama_kv_cache_unified> kv_swa;
 };
index d3129cc53281e6589ebd7e7a3cae1ea6407878c4..321dc79fc36ab708a4ac96076b3fabf200568a3e 100644 (file)
@@ -23,13 +23,14 @@ llama_kv_cache_unified::llama_kv_cache_unified(
                 ggml_type    type_v,
                      bool    v_trans,
                      bool    offload,
+                     bool    unified,
                  uint32_t    kv_size,
                  uint32_t    n_seq_max,
                  uint32_t    n_pad,
                  uint32_t    n_swa,
            llama_swa_type    swa_type) :
     model(model), hparams(model.hparams), v_trans(v_trans),
-    n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
 
     GGML_ASSERT(kv_size % n_pad == 0);
 
@@ -45,7 +46,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -64,9 +65,33 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         return it->second;
     };
 
-    head = 0;
+    GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
 
-    cells.resize(kv_size);
+    v_heads.resize(n_stream);
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        v_heads[s] = 0;
+    }
+
+    v_cells.resize(n_stream);
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        v_cells[s].resize(kv_size);
+    }
+
+    // by default, all sequence ids are mapped to the 0th stream
+    seq_to_stream.resize(LLAMA_MAX_SEQ, 0);
+
+    if (n_stream > 1) {
+        seq_to_stream.resize(n_stream, 0);
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            seq_to_stream[s] = s;
+        }
+    }
+
+    // [TAG_V_CACHE_VARIABLE]
+    if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
+        LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
+                __func__, hparams.n_embd_v_gqa_max());
+    }
 
     for (uint32_t il = 0; il < n_layer_cache; il++) {
         if (filter && !filter(il)) {
@@ -74,8 +99,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+        // [TAG_V_CACHE_VARIABLE]
+        const uint32_t n_embd_k_gqa =            hparams.n_embd_k_gqa(il);
+        const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
 
         const char * dev_name = "CPU";
 
@@ -98,14 +124,23 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_tensor * k;
         ggml_tensor * v;
 
-        k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
+        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
 
+        std::vector<ggml_tensor *> k_stream;
+        std::vector<ggml_tensor *> v_stream;
+
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
+            v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
+        }
+
         map_layer_ids[il] = layers.size();
-        layers.push_back({ il, k, v });
+
+        layers.push_back({ il, k, v, k_stream, v_stream, });
     }
 
     // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
@@ -148,8 +183,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         const size_t memory_size_k = size_k_bytes();
         const size_t memory_size_v = size_v_bytes();
 
-        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
+        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
                 ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
     }
@@ -158,7 +193,12 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
 
     const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
-    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : 0;
+
+    if (!supports_set_rows) {
+        // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+        GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support");
+    }
 
     if (!supports_set_rows) {
         LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
@@ -166,9 +206,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 }
 
 void llama_kv_cache_unified::clear(bool data) {
-    cells.reset();
-
-    head = 0;
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        v_cells[s].reset();
+        v_heads[s] = 0;
+    }
 
     if (data) {
         for (auto & buf : bufs) {
@@ -178,6 +219,11 @@ void llama_kv_cache_unified::clear(bool data) {
 }
 
 bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    auto & cells = v_cells[seq_to_stream[seq_id]];
+    auto & head  = v_heads[seq_to_stream[seq_id]];
+
     uint32_t new_head = cells.size();
 
     if (p0 < 0) {
@@ -224,30 +270,94 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
 }
 
 void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    if (seq_id_src == seq_id_dst) {
+    GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
+    GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
+
+    const auto s0 = seq_to_stream[seq_id_src];
+    const auto s1 = seq_to_stream[seq_id_dst];
+
+    if (s0 == s1) {
+        // since both sequences are in the same stream, no data copy is necessary
+        // we just have to update the cells meta data
+
+        auto & cells = v_cells[s0];
+
+        if (seq_id_src == seq_id_dst) {
+            return;
+        }
+
+        if (p0 < 0) {
+            p0 = 0;
+        }
+
+        if (p1 < 0) {
+            p1 = std::numeric_limits<llama_pos>::max();
+        }
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            if (!cells.pos_in(i, p0, p1)) {
+                continue;
+            }
+
+            if (cells.seq_has(i, seq_id_src)) {
+                cells.seq_add(i, seq_id_dst);
+            }
+        }
+
         return;
     }
 
-    if (p0 < 0) {
-        p0 = 0;
+    // cross-stream sequence copies require to copy the actual buffer data
+
+    bool is_full = true;
+
+    if (p0 > 0 && p0 + 1 < (int) get_size()) {
+        is_full = false;
     }
 
-    if (p1 < 0) {
-        p1 = std::numeric_limits<llama_pos>::max();
+    if (p1 > 0 && p1 + 1 < (int) get_size()) {
+        is_full = false;
     }
 
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (!cells.pos_in(i, p0, p1)) {
-            continue;
-        }
+    GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
+
+    // enqueue the copy operation - the buffer copy will be performed during the next update
+    sc_info.ssrc.push_back(s0);
+    sc_info.sdst.push_back(s1);
 
-        if (cells.seq_has(i, seq_id_src)) {
-            cells.seq_add(i, seq_id_dst);
+    v_cells[s1].reset();
+    for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
+        if (v_cells[s0].seq_has(i, seq_id_src)) {
+            llama_pos pos   = v_cells[s0].pos_get(i);
+            llama_pos shift = v_cells[s0].get_shift(i);
+
+            if (shift != 0) {
+                pos -= shift;
+                assert(pos >= 0);
+            }
+
+            v_cells[s1].pos_set(i, pos);
+            v_cells[s1].seq_add(i, seq_id_dst);
+
+            if (shift != 0) {
+                v_cells[s1].pos_add(i, shift);
+            }
         }
     }
+
+    v_heads[s1] = v_heads[s0];
+
+    //for (uint32_t s = 0; s < n_stream; ++s) {
+    //    LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
+    //}
 }
 
 void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    auto & cells = v_cells[seq_to_stream[seq_id]];
+    auto & head  = v_heads[seq_to_stream[seq_id]];
+
     uint32_t new_head = cells.size();
 
     for (uint32_t i = 0; i < cells.size(); ++i) {
@@ -265,6 +375,11 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
 }
 
 void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    auto & cells = v_cells[seq_to_stream[seq_id]];
+    auto & head  = v_heads[seq_to_stream[seq_id]];
+
     if (shift == 0) {
         return;
     }
@@ -304,6 +419,10 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
 }
 
 void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    auto & cells = v_cells[seq_to_stream[seq_id]];
+
     if (d == 1) {
         return;
     }
@@ -333,10 +452,18 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
 }
 
 llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    const auto & cells = v_cells[seq_to_stream[seq_id]];
+
     return cells.seq_pos_min(seq_id);
 }
 
 llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    const auto & cells = v_cells[seq_to_stream[seq_id]];
+
     return cells.seq_pos_max(seq_id);
 }
 
@@ -351,7 +478,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
 
         std::vector<llama_ubatch> ubatches;
         while (true) {
-            auto ubatch = balloc.split_simple(n_ubatch);
+            auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
 
             if (ubatch.n_tokens == 0) {
                 break;
@@ -387,7 +514,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
     defrag_info dinfo;
 
     // see if we need to defrag
-    {
+    if (n_stream == 1) {
+        // note : for now do not consider defrag for n_stream > 1
+        const auto & cells = v_cells[seq_to_stream[0]];
+
         bool do_defrag = optimize;
 
         const auto thold = lctx->get_cparams().defrag_thold;
@@ -411,22 +541,22 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
         }
     }
 
-    return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
+    return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
 }
 
 llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
     llama_kv_cache_unified::slot_info_vec_t res;
 
-    struct state {
-        uint32_t head_old; // old position of the head, before placing the ubatch
-
+    struct state_t {
         slot_info sinfo; // slot info for the ubatch
 
-        llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
+        std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
+
+        std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
     };
 
     // remember the old state of the cells so we can restore it in the end
-    std::vector<state> states;
+    std::vector<state_t> states;
 
     bool success = true;
 
@@ -445,16 +575,35 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
         res.push_back(sinfo_new);
 
         // store the old state of the cells in the recovery stack
-        states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
+        {
+            state_t state = { sinfo_new, v_heads, {} };
+
+            for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
+                auto & cells = v_cells[sinfo_new.strm[s]];
+
+                state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
+            }
+
+            states.push_back(std::move(state));
+        }
 
         // now emplace the ubatch
         apply_ubatch(sinfo_new, ubatch);
     }
 
+    GGML_ASSERT(!states.empty() || !success);
+
     // iterate backwards and restore the cells to their original state
     for (auto it = states.rbegin(); it != states.rend(); ++it) {
-        cells.set(it->sinfo.idxs, it->cells);
-        head = it->head_old;
+        const auto & sinfo = it->sinfo;
+
+        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+            auto & cells = v_cells[sinfo.strm[s]];
+            auto & head  = v_heads[sinfo.strm[s]];
+
+            cells.set(sinfo.idxs[s], it->v_cells[s]);
+            head = it->v_heads_old[s];
+        }
     }
 
     if (!success) {
@@ -464,11 +613,38 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
     return res;
 }
 
-bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
+bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
     bool updated = false;
 
     auto * sched = lctx->get_sched();
 
+    if (!sc_info.empty()) {
+        assert(n_stream > 1 && "stream copy should never happen with a single stream");
+
+        llama_synchronize(lctx);
+
+        const size_t n_copy = sc_info.ssrc.size();
+
+        for (size_t i = 0; i < n_copy; ++i) {
+            const auto ssrc = sc_info.ssrc[i];
+            const auto sdst = sc_info.sdst[i];
+
+            assert(ssrc < n_stream);
+            assert(sdst < n_stream);
+
+            LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
+
+            assert(ssrc != sdst);
+
+            for (uint32_t il = 0; il < layers.size(); ++il) {
+                const auto & layer = layers[il];
+
+                ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
+                ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
+            }
+        }
+    }
+
     if (do_shift) {
         if (!get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");
@@ -480,14 +656,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
         if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
             ggml_backend_sched_reset(sched);
 
-            auto * gf = lctx->graph_init();
+            auto * res = lctx->get_gf_res_reserve();
 
-            auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
-            if (!res) {
-                LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
-                return updated;
-            }
+            res->reset();
 
+            auto * gf = build_graph_shift(res, lctx);
             if (!ggml_backend_sched_alloc_graph(sched, gf)) {
                 LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
                 return updated;
@@ -503,12 +676,20 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
             updated = true;
         }
 
-        cells.reset_shift();
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            auto & cells = v_cells[s];
+
+            cells.reset_shift();
+        }
     }
 
     if (!dinfo.empty()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
+        // note: for now do not consider defrag for n_stream > 1
+        auto & cells = v_cells[seq_to_stream[0]];
+        auto & head  = v_heads[seq_to_stream[0]];
+
         // apply moves:
         {
             const auto n_kv = dinfo.ids.size();
@@ -529,14 +710,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
 
         ggml_backend_sched_reset(sched);
 
-        auto * gf = lctx->graph_init();
+        auto * res = lctx->get_gf_res_reserve();
 
-        auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
-        if (!res) {
-            LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
-            return updated;
-        }
+        res->reset();
 
+        auto * gf = build_graph_defrag(res, lctx, dinfo);
         if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
             return updated;
@@ -556,23 +734,13 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
 }
 
 llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
-    const uint32_t n_tokens = ubatch.n_tokens;
+    if (debug > 0) {
+        const auto & cells = v_cells[seq_to_stream[1]];
 
-    uint32_t head_cur = this->head;
+        const uint32_t head_cur = v_heads[1];
 
-    // if we have enough unused cells before the current head ->
-    //   better to start searching from the beginning of the cache, hoping to fill it
-    if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
-        head_cur = 0;
-    }
-
-    if (n_tokens > cells.size()) {
-        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-        return { };
-    }
-
-    if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
+        LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
+                __func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
 
         if ((debug == 2 && n_swa > 0) || debug > 2) {
             std::string ss;
@@ -629,86 +797,133 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
         }
     }
 
-    uint32_t n_tested = 0;
+    uint32_t n_tokens = ubatch.n_tokens;
+    uint32_t n_seqs   = 1;
+
+    if (n_stream > 1) {
+        GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
 
-    // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
-    // for non-continuous slots, we test the tokens one by one
-    const uint32_t n_test = cont ? n_tokens : 1;
+        n_seqs   = ubatch.n_seqs_unq;
+        n_tokens = n_tokens / n_seqs;
+    }
 
-    slot_info res;
+    slot_info res = {
+        /*.s0   =*/ LLAMA_MAX_SEQ,
+        /*.s1   =*/ 0,
+        /*.strm =*/ { },
+        /*.idxs =*/ { },
+    };
 
-    auto & idxs = res.idxs;
+    res.resize(n_seqs);
 
-    idxs.reserve(n_tokens);
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const auto seq_id = ubatch.seq_id_unq[s];
 
-    while (true) {
-        if (head_cur + n_test > cells.size()) {
-            n_tested += cells.size() - head_cur;
+        if (n_stream > 1) {
+            GGML_ASSERT(ubatch.n_seq_id[s*n_tokens]    == 1);
+            GGML_ASSERT(ubatch.seq_id  [s*n_tokens][0] == seq_id);
+        }
+
+        res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
+        res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
+
+        res.strm[s] = seq_to_stream[seq_id];
+        res.idxs[s].reserve(n_tokens);
+
+        const auto & cells = v_cells[seq_to_stream[seq_id]];
+
+        uint32_t head_cur = v_heads[seq_to_stream[seq_id]];
+
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (head_cur > cells.get_used() + 2*n_tokens) {
             head_cur = 0;
-            continue;
         }
 
-        for (uint32_t i = 0; i < n_test; i++) {
-            const auto idx = head_cur;
+        if (n_tokens > cells.size()) {
+            LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
+            return { };
+        }
+
+        uint32_t n_tested = 0;
+
+        // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
+        // for non-continuous slots, we test the tokens one by one
+        const uint32_t n_test = cont ? n_tokens : 1;
 
-            //const llama_pos    pos    = ubatch.pos[i];
-            //const llama_seq_id seq_id = ubatch.seq_id[i][0];
+        while (true) {
+            if (head_cur + n_test > cells.size()) {
+                n_tested += cells.size() - head_cur;
+                head_cur = 0;
+                continue;
+            }
 
-            // can we use this cell? either:
-            //  - the cell is empty
-            //  - the cell is occupied only by one sequence:
-            //    - (disabled) mask causally, if the sequence is the same as the one we are inserting
-            //    - mask SWA, using current max pos for that sequence in the cache
-            //                always insert in the cell with minimum pos
-            bool can_use = cells.is_empty(idx);
+            for (uint32_t i = 0; i < n_test; i++) {
+                const auto idx = head_cur;
 
-            if (!can_use && cells.seq_count(idx) == 1) {
-                const llama_pos pos_cell = cells.pos_get(idx);
+                head_cur++;
+                n_tested++;
 
-                // (disabled) causal mask
-                // note: it's better to purge any "future" tokens beforehand
-                //if (cells.seq_has(idx, seq_id)) {
-                //    can_use = pos_cell >= pos;
-                //}
+                //const llama_pos    pos    = ubatch.pos[i];
+                //const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
-                if (!can_use) {
-                    const llama_seq_id seq_id_cell = cells.seq_get(idx);
+                // can we use this cell? either:
+                //  - the cell is empty
+                //  - the cell is occupied only by one sequence:
+                //    - (disabled) mask causally, if the sequence is the same as the one we are inserting
+                //    - mask SWA, using current max pos for that sequence in the cache
+                //                always insert in the cell with minimum pos
+                bool can_use = cells.is_empty(idx);
 
-                    // SWA mask
-                    if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
-                        can_use = true;
+                if (!can_use && cells.seq_count(idx) == 1) {
+                    const llama_pos pos_cell = cells.pos_get(idx);
+
+                    // (disabled) causal mask
+                    // note: it's better to purge any "future" tokens beforehand
+                    //if (cells.seq_has(idx, seq_id)) {
+                    //    can_use = pos_cell >= pos;
+                    //}
+
+                    if (!can_use) {
+                        const llama_seq_id seq_id_cell = cells.seq_get(idx);
+
+                        // SWA mask
+                        if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
+                            can_use = true;
+                        }
                     }
                 }
-            }
 
-            head_cur++;
-            n_tested++;
+                if (can_use) {
+                    res.idxs[s].push_back(idx);
+                } else {
+                    if (cont) {
+                        break;
+                    }
+                }
+            }
 
-            if (can_use) {
-                idxs.push_back(idx);
-            } else {
+            if (res.idxs[s].size() == n_tokens) {
                 break;
             }
-        }
 
-        if (idxs.size() == n_tokens) {
-            break;
-        }
+            if (cont) {
+                res.idxs[s].clear();
+            }
 
-        if (cont) {
-            idxs.clear();
+            if (n_tested >= cells.size()) {
+                //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+                return { };
+            }
         }
 
-        if (n_tested >= cells.size()) {
-            //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+        // we didn't find a suitable slot - return empty result
+        if (res.idxs[s].size() < n_tokens) {
             return { };
         }
     }
 
-    // we didn't find a suitable slot - return empty result
-    if (idxs.size() < n_tokens) {
-        res.clear();
-    }
+    assert(res.s1 >= res.s0);
 
     return res;
 }
@@ -717,41 +932,51 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
-    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
         seq_pos_max_rm[s] = -1;
     }
 
-    assert(ubatch.n_tokens == sinfo.idxs.size());
+    assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
 
-    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        const auto idx = sinfo.idxs.at(i);
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
+            const uint32_t i = s*sinfo.size() + ii;
 
-        if (!cells.is_empty(idx)) {
-            assert(cells.seq_count(idx) == 1);
+            auto & cells = v_cells[sinfo.strm[s]];
 
-            const llama_seq_id seq_id = cells.seq_get(idx);
-            const llama_pos    pos    = cells.pos_get(idx);
+            const auto idx = sinfo.idxs[s][ii];
 
-            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+            if (!cells.is_empty(idx)) {
+                assert(cells.seq_count(idx) == 1);
 
-            cells.rm(idx);
-        }
+                const llama_seq_id seq_id = cells.seq_get(idx);
+                const llama_pos    pos    = cells.pos_get(idx);
 
-        cells.pos_set(idx, ubatch.pos[i]);
+                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
+                cells.rm(idx);
+            }
 
-        for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-            cells.seq_add(idx, ubatch.seq_id[i][s]);
+            cells.pos_set(idx, ubatch.pos[i]);
+
+            for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
+                cells.seq_add(idx, ubatch.seq_id[i][s]);
+            }
         }
     }
 
     // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
     //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
     //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (seq_pos_max_rm[s] == -1) {
             continue;
         }
 
+        GGML_ASSERT(s < seq_to_stream.size());
+
+        auto & cells = v_cells[seq_to_stream[s]];
+
         if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
             LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
                     __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
@@ -761,7 +986,11 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
     }
 
     // move the head at the end of the slot
-    head = sinfo.idxs.back() + 1;
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        auto & head = v_heads[sinfo.strm[s]];
+
+        head = sinfo.idxs[s].back() + 1;
+    }
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
@@ -769,49 +998,91 @@ bool llama_kv_cache_unified::get_can_shift() const {
 }
 
 uint32_t llama_kv_cache_unified::get_size() const {
+    const auto & cells = v_cells[seq_to_stream[0]];
+
     return cells.size();
 }
 
+uint32_t llama_kv_cache_unified::get_n_stream() const {
+    return n_stream;
+}
+
 bool llama_kv_cache_unified::get_has_shift() const {
-    return cells.get_has_shift();
+    bool result = false;
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        result |= v_cells[s].get_has_shift();
+    }
+
+    return result;
 }
 
 uint32_t llama_kv_cache_unified::get_n_kv() const {
-    return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
+    uint32_t result = 0;
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        const auto & cells = v_cells[s];
+
+        result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
+    }
+
+    return result;
 }
 
-ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
+bool llama_kv_cache_unified::get_supports_set_rows() const {
+    return supports_set_rows;
+}
+
+ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
 
-    return ggml_view_3d(ctx, k,
-            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
+    const uint64_t kv_size      = get_size();
+    const uint64_t n_embd_k_gqa = k->ne[0];
+
+    assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
+
+    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
+    return ggml_view_4d(ctx, k,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
             ggml_row_size(k->type, hparams.n_embd_head_k),
-            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
-            0);
+            ggml_row_size(k->type, n_embd_k_gqa),
+            ggml_row_size(k->type, n_embd_k_gqa*kv_size),
+            ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
 }
 
-ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
+ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
 
+    const uint64_t kv_size      = get_size();
+    const uint64_t n_embd_v_gqa = v->ne[0];
+
+    // [TAG_V_CACHE_VARIABLE]
+    assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
+
+    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
-        return ggml_view_3d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
-                ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-                0);
+        return ggml_view_4d(ctx, v,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
+                ggml_row_size(v->type, hparams.n_embd_head_v),            // v->nb[1]
+                ggml_row_size(v->type, n_embd_v_gqa),         // v->nb[2]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
     }
 
     // note: v->nb[1] > v->nb[2]
-    return ggml_view_3d(ctx, v,
-            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
-            ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
-            ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
-            0);
+    return ggml_view_4d(ctx, v,
+            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),    // v->nb[1]
+            ggml_row_size(v->type, kv_size),                          // v->nb[2]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
 }
 
 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
@@ -825,12 +1096,18 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
     if (k_idxs && supports_set_rows) {
+        if (k->ne[2] > 1) {
+            k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        }
+
         return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
     //       will be removed when ggml_set_rows() is adopted by all backends
 
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*n_embd_k_gqa,
             ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
@@ -843,37 +1120,38 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     auto * v = layers[ikv].v;
 
-    const int64_t n_embd_v_gqa = v->ne[0];
-    const int64_t n_tokens = v_cur->ne[2];
+    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
+    const int64_t n_tokens     = v_cur->ne[2];
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
     if (v_idxs && supports_set_rows) {
         if (!v_trans) {
+            if (v->ne[2] > 1) {
+                v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+            }
+
             return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
-        // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+        // [TAG_V_CACHE_VARIABLE]
+        if (n_embd_v_gqa < v->ne[0]) {
+            v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+        }
 
-        // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+        // the row becomes a single element
+        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
 
-        // note: we can be more explicit here at the cost of extra cont
-        //       however, above we take advantage that a row of single element is always continuous regardless of the row stride
-        //v_cur = ggml_transpose(ctx, v_cur);
-        //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
 
-        // we broadcast the KV indices n_embd_v_gqa times
-        // v      [1,        n_kv,     n_embd_v_gqa]
-        // v_cur  [1,        n_tokens, n_embd_v_gqa]
-        // v_idxs [n_tokens, 1,        1]
         return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
     //       will be removed when ggml_set_rows() is adopted by all backends
 
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
+
     ggml_tensor * v_view = nullptr;
 
     if (!v_trans) {
@@ -904,7 +1182,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
 ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
-    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
+    }
 
     ggml_set_input(v_idxs);
 
@@ -917,12 +1201,17 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
     }
 
     const uint32_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = sinfo.idxs.at(i);
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        const int64_t offs = sinfo.strm[s]*get_size();
+
+        for (uint32_t i = 0; i < sinfo.size(); ++i) {
+            data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+        }
     }
 }
 
@@ -932,12 +1221,48 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
     }
 
     const uint32_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = sinfo.idxs.at(i);
+    if (!v_trans) {
+        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+            const int64_t offs = sinfo.strm[s]*get_size();
+
+            for (uint32_t i = 0; i < sinfo.size(); ++i) {
+                data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+            }
+        }
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        const int64_t kv_size = get_size();
+
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
+
+        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+            const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;
+
+            for (uint32_t i = 0; i < sinfo.size(); ++i) {
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
+                }
+            }
+        }
+    }
+}
+
+void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+
+    int32_t * data = (int32_t *) dst->data;
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        const auto & cells = v_cells[s];
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
+        }
     }
 }
 
@@ -947,7 +1272,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;
 
-    const int64_t n_kv = dst->ne[0];
+    const int64_t n_kv     = dst->ne[0];
+    const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
+
+    GGML_ASSERT(n_tokens%n_stream == 0);
+
+    // n_tps == n_tokens_per_stream
+    const int64_t n_tps     = n_tokens/n_stream;
+    const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
+
+    std::fill(data, data + ggml_nelements(dst), -INFINITY);
 
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -961,70 +1295,57 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //      xxxxx-----
     //      xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+    // TODO: optimize this section
     for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t i = 0; i < n_tokens; ++i) {
-            const llama_seq_id seq_id = ubatch->seq_id[i][0];
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            for (uint32_t ii = 0; ii < n_tps; ++ii) {
+                const uint32_t i = s*n_tps + ii;
 
-            const llama_pos p1 = ubatch->pos[i];
+                const llama_seq_id seq_id = ubatch->seq_id[i][0];
 
-            for (uint32_t j = 0; j < n_kv; ++j) {
-                float f = 0.0f;
+                const auto & cells = v_cells[seq_to_stream[seq_id]];
 
-                bool masked = false;
+                const llama_pos p1 = ubatch->pos[i];
 
-                if (cells.is_empty(j)) {
-                    masked = true;
-                } else {
-                    const llama_pos p0 = cells.pos_get(j);
+                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
+
+                for (uint32_t j = 0; j < n_kv; ++j) {
+                    if (cells.is_empty(j)) {
+                        continue;
+                    }
 
                     // mask the token if not the same sequence
-                    masked = masked || (!cells.seq_has(j, seq_id));
+                    if (!cells.seq_has(j, seq_id)) {
+                        continue;
+                    }
+
+                    const llama_pos p0 = cells.pos_get(j);
 
                     // mask future tokens
-                    masked = masked || (causal_attn && p0 > p1);
+                    if (causal_attn && p0 > p1) {
+                        continue;
+                    }
 
                     // apply SWA if any
-                    masked = masked || (is_masked_swa(p0, p1));
-
-                    if (!masked && hparams.use_alibi) {
-                        f = -std::abs(p0 - p1);
+                    if (is_masked_swa(p0, p1)) {
+                        continue;
                     }
-                }
-
-                if (masked) {
-                    f = -INFINITY;
-                }
-
-                data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
-            }
-        }
 
-        // mask padded tokens
-        if (data) {
-            for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                 }
             }
         }
     }
 }
 
-void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-
-    int32_t * data = (int32_t *) dst->data;
-
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
-    }
-}
-
 void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     const int64_t n_tokens = ubatch->n_tokens;
 
+    GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
+    const auto & cells = v_cells[0];
+
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
 
     int32_t * data = (int32_t *) dst->data;
 
@@ -1129,7 +1450,7 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * k_shift; // I32 [kv_size]
+    ggml_tensor * k_shift; // I32 [kv_size*n_stream]
 
     const llama_kv_cache_unified * kv_self;
 };
@@ -1142,20 +1463,20 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
-        const llama_cparams & cparams,
-               ggml_context * ctx,
-                ggml_cgraph * gf) const {
-    auto res = std::make_unique<llm_graph_result>();
+ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
+    auto * ctx = res->get_ctx();
+    auto * gf  = res->get_gf();
 
     const auto & n_embd_head_k = hparams.n_embd_head_k;
   //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
     auto inp = std::make_unique<llm_graph_input_k_shift>(this);
 
-    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
+    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
     ggml_set_input(inp->k_shift);
 
+    const auto & cparams = lctx->get_cparams();
+
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
@@ -1169,7 +1490,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 
         ggml_tensor * k =
             ggml_view_3d(ctx, layer.k,
-                n_embd_head_k, n_head_kv, cells.size(),
+                n_embd_head_k, n_head_kv, get_size()*n_stream,
                 ggml_row_size(layer.k->type, n_embd_head_k),
                 ggml_row_size(layer.k->type, n_embd_k_gqa),
                 0);
@@ -1181,18 +1502,24 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 
     res->add_input(std::move(inp));
 
-    return res;
+    return gf;
 }
 
-llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
-                const llama_cparams & cparams,
-                       ggml_context * ctx,
-                        ggml_cgraph * gf,
-                  const defrag_info & dinfo) const {
-    auto res = std::make_unique<llm_graph_result>();
+ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
+         llm_graph_result * res,
+            llama_context * lctx,
+        const defrag_info & dinfo) const {
+    auto * ctx = res->get_ctx();
+    auto * gf  = res->get_gf();
+
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
+
+    const auto & cells = v_cells[0];
 
     const auto & ids = dinfo.ids;
 
+    const auto & cparams = lctx->get_cparams();
+
 #if 0
     // CPU defrag
     //
@@ -1329,10 +1656,14 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
     //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
 #endif
 
-    return res;
+    return gf;
 }
 
 llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
+
+    const auto & cells = v_cells[0];
+
     const uint32_t n_layer = layers.size();
 
     const uint32_t n_kv   = cells.used_max_p1();
@@ -1478,64 +1809,94 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
 }
 
 void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
-    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
-    uint32_t cell_count = 0;
+    io.write(&n_stream, sizeof(n_stream));
 
-    // Count the number of cells with the specified seq_id
-    // Find all the ranges of cells with this seq id (or all, when -1)
-    uint32_t cell_range_begin = cells.size();
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        cell_ranges_t cr { s, {} };
 
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
-            ++cell_count;
-            if (cell_range_begin == cells.size()) {
-                cell_range_begin = i;
-            }
-        } else {
-            if (cell_range_begin != cells.size()) {
-                cell_ranges.emplace_back(cell_range_begin, i);
-                cell_range_begin = cells.size();
+        uint32_t cell_count = 0;
+
+        const auto & cells = v_cells[s];
+
+        // Count the number of cells with the specified seq_id
+        // Find all the ranges of cells with this seq id (or all, when -1)
+        uint32_t cell_range_begin = cells.size();
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
+                ++cell_count;
+                if (cell_range_begin == cells.size()) {
+                    cell_range_begin = i;
+                }
+            } else {
+                if (cell_range_begin != cells.size()) {
+                    cr.data.emplace_back(cell_range_begin, i);
+                    cell_range_begin = cells.size();
+                }
             }
         }
-    }
 
-    if (cell_range_begin != cells.size()) {
-        cell_ranges.emplace_back(cell_range_begin, cells.size());
-    }
+        if (cell_range_begin != cells.size()) {
+            cr.data.emplace_back(cell_range_begin, cells.size());
+        }
 
-    // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
-    uint32_t cell_count_check = 0;
-    for (const auto & range : cell_ranges) {
-        cell_count_check += range.second - range.first;
-    }
-    GGML_ASSERT(cell_count == cell_count_check);
+        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+        uint32_t cell_count_check = 0;
+        for (const auto & range : cr.data) {
+            cell_count_check += range.second - range.first;
+        }
+        GGML_ASSERT(cell_count == cell_count_check);
 
-    io.write(&cell_count, sizeof(cell_count));
+        io.write(&cell_count, sizeof(cell_count));
 
-    state_write_meta(io, cell_ranges, seq_id);
-    state_write_data(io, cell_ranges);
+        // skip empty streams
+        if (cell_count == 0) {
+            continue;
+        }
+
+        state_write_meta(io, cr, seq_id);
+        state_write_data(io, cr);
+    }
 }
 
 void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
-    uint32_t cell_count;
-    io.read_to(&cell_count, sizeof(cell_count));
+    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
 
-    bool res = true;
-    res = res && state_read_meta(io, cell_count, seq_id);
-    res = res && state_read_data(io, cell_count);
+    uint32_t n_stream_cur;
+    io.read_to(&n_stream_cur, sizeof(n_stream_cur));
+    if (n_stream_cur != n_stream) {
+        throw std::runtime_error("n_stream mismatch");
+    }
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        uint32_t cell_count;
+        io.read_to(&cell_count, sizeof(cell_count));
+
+        if (cell_count == 0) {
+            continue;
+        }
 
-    if (!res) {
-        if (seq_id == -1) {
-            clear(true);
-        } else {
-            seq_rm(seq_id, -1, -1);
+        const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
+
+        bool res = true;
+        res = res && state_read_meta(io, strm, cell_count, seq_id);
+        res = res && state_read_data(io, strm, cell_count);
+
+        if (!res) {
+            if (seq_id == -1) {
+                clear(true);
+            } else {
+                seq_rm(seq_id, -1, -1);
+            }
+            throw std::runtime_error("failed to restore kv cache");
         }
-        throw std::runtime_error("failed to restore kv cache");
     }
 }
 
-void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
-    for (const auto & range : cell_ranges) {
+void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
+    const auto & cells = v_cells[cr.strm];
+
+    for (const auto & range : cr.data) {
         for (uint32_t i = range.first; i < range.second; ++i) {
             std::vector<llama_seq_id> seq_ids;
 
@@ -1560,7 +1921,9 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
     }
 }
 
-void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
+void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
+    const auto & cells = v_cells[cr.strm];
+
     const uint32_t v_trans = this->v_trans ? 1 : 0;
     const uint32_t n_layer = layers.size();
 
@@ -1576,19 +1939,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
 
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
+        auto * k = layer.k_stream[cr.strm];
+
         // Write key type
-        const int32_t k_type_i = (int32_t)layer.k->type;
+        const int32_t k_type_i = (int32_t) k->type;
         io.write(&k_type_i, sizeof(k_type_i));
 
         // Write row size of key
-        const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
+        const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
         io.write(&k_size_row, sizeof(k_size_row));
 
         // Read each range of cells of k_size length each into tmp_buf and write out
-        for (const auto & range : cell_ranges) {
+        for (const auto & range : cr.data) {
             const size_t range_size = range.second - range.first;
             const size_t buf_size = range_size * k_size_row;
-            io.write_tensor(layer.k, range.first * k_size_row, buf_size);
+            io.write_tensor(k, range.first * k_size_row, buf_size);
         }
     }
 
@@ -1598,19 +1963,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
 
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
+            auto * v = layer.v_stream[cr.strm];
+
             // Write value type
-            const int32_t v_type_i = (int32_t)layer.v->type;
+            const int32_t v_type_i = (int32_t) v->type;
             io.write(&v_type_i, sizeof(v_type_i));
 
             // Write row size of value
-            const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
+            const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
             io.write(&v_size_row, sizeof(v_size_row));
 
             // Read each range of cells of v_size length each into tmp_buf and write out
-            for (const auto & range : cell_ranges) {
+            for (const auto & range : cr.data) {
                 const size_t range_size = range.second - range.first;
                 const size_t buf_size = range_size * v_size_row;
-                io.write_tensor(layer.v, range.first * v_size_row, buf_size);
+                io.write_tensor(v, range.first * v_size_row, buf_size);
             }
         }
     } else {
@@ -1622,12 +1989,14 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
 
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
+            auto * v = layer.v_stream[cr.strm];
+
             // Write value type
-            const int32_t v_type_i = (int32_t)layer.v->type;
+            const int32_t v_type_i = (int32_t) v->type;
             io.write(&v_type_i, sizeof(v_type_i));
 
             // Write element size
-            const uint32_t v_size_el = ggml_type_size(layer.v->type);
+            const uint32_t v_size_el = ggml_type_size(v->type);
             io.write(&v_size_el, sizeof(v_size_el));
 
             // Write GQA embedding size
@@ -1636,27 +2005,31 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
             // For each row, we get the element values of each cell
             for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                 // Read each range of cells of v_size_el length each into tmp_buf and write out
-                for (const auto & range : cell_ranges) {
+                for (const auto & range : cr.data) {
                     const size_t range_size = range.second - range.first;
                     const size_t src_offset = (range.first + j * kv_size) * v_size_el;
                     const size_t buf_size = range_size * v_size_el;
-                    io.write_tensor(layer.v, src_offset, buf_size);
+                    io.write_tensor(v, src_offset, buf_size);
                 }
             }
         }
     }
 }
 
-bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
+    auto & cells = v_cells[strm];
+    auto & head  = v_heads[strm];
+
     if (dest_seq_id != -1) {
         // single sequence
-
         seq_rm(dest_seq_id, -1, -1);
 
         llama_batch_allocr balloc(hparams.n_pos_per_embd());
 
         llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
 
+        ubatch.seq_id_unq[0] = dest_seq_id;
+
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
             uint32_t n_seq_id;
@@ -1693,6 +2066,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         // keep the head at the old position because we will read the KV data into it in state_read_data()
         head = head_cur;
 
+        LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
+
         // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
         GGML_ASSERT(head_cur + cell_count <= cells.size());
@@ -1738,7 +2113,10 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
     return true;
 }
 
-bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
+bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
+    auto & cells = v_cells[strm];
+    auto & head  = v_heads[strm];
+
     uint32_t v_trans;
     uint32_t n_layer;
 
@@ -1766,10 +2144,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
+        auto * k = layer.k_stream[strm];
+
         // Read type of key
         int32_t k_type_i_ref;
         io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
-        const int32_t k_type_i = (int32_t) layer.k->type;
+        const int32_t k_type_i = (int32_t) k->type;
         if (k_type_i != k_type_i_ref) {
             LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
             return false;
@@ -1778,7 +2158,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         // Read row size of key
         uint64_t k_size_row_ref;
         io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-        const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
+        const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
         if (k_size_row != k_size_row_ref) {
             LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
             return false;
@@ -1786,7 +2166,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 
         if (cell_count) {
             // Read and set the keys for the whole cell range
-            ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+            ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
         }
     }
 
@@ -1796,10 +2176,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
+            auto * v = layer.v_stream[strm];
+
             // Read type of value
             int32_t v_type_i_ref;
             io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-            const int32_t v_type_i = (int32_t)layer.v->type;
+            const int32_t v_type_i = (int32_t) v->type;
             if (v_type_i != v_type_i_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                 return false;
@@ -1808,7 +2190,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
             // Read row size of value
             uint64_t v_size_row_ref;
             io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-            const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
+            const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
             if (v_size_row != v_size_row_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
                 return false;
@@ -1816,7 +2198,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 
             if (cell_count) {
                 // Read and set the values for the whole cell range
-                ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+                ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
             }
         }
     } else {
@@ -1826,10 +2208,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
+            auto * v = layer.v_stream[strm];
+
             // Read type of value
             int32_t v_type_i_ref;
             io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-            const int32_t v_type_i = (int32_t)layer.v->type;
+            const int32_t v_type_i = (int32_t) v->type;
             if (v_type_i != v_type_i_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                 return false;
@@ -1838,7 +2222,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
             // Read element size of value
             uint32_t v_size_el_ref;
             io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
-            const size_t v_size_el = ggml_type_size(layer.v->type);
+            const size_t v_size_el = ggml_type_size(v->type);
             if (v_size_el != v_size_el_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
                 return false;
@@ -1856,7 +2240,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
                 // For each row in the transposed matrix, read the values for the whole cell range
                 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                     const size_t dst_offset = (head + j * cells.size()) * v_size_el;
-                    ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                 }
             }
         }
@@ -1875,18 +2259,26 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
 
+    const uint32_t n_stream = kv->get_n_stream();
+
     // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
     sinfos.resize(1);
-    sinfos[0].idxs.resize(1);
-    sinfos[0].idxs[0] = 0;
+    sinfos[0].s0 = 0;
+    sinfos[0].s1 = n_stream - 1;
+    sinfos[0].idxs.resize(n_stream);
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        sinfos[0].strm.push_back(s);
+        sinfos[0].idxs[s].resize(1, 0);
+    }
 }
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
         llama_context * lctx,
         bool do_shift,
-        defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
-    if (!do_shift && this->dinfo.empty()) {
+        defrag_info dinfo,
+        stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) {
+    if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) {
         status = LLAMA_MEMORY_STATUS_NO_UPDATE;
     }
 }
@@ -1914,7 +2306,7 @@ bool llama_kv_cache_unified_context::apply() {
 
     // no ubatches -> this is a KV cache update
     if (ubatches.empty()) {
-        kv->update(lctx, do_shift, dinfo);
+        kv->update(lctx, do_shift, dinfo, sc_info);
 
         return true;
     }
@@ -1940,12 +2332,16 @@ uint32_t llama_kv_cache_unified_context::get_n_kv() const {
     return n_kv;
 }
 
+bool llama_kv_cache_unified_context::get_supports_set_rows() const {
+    return kv->get_supports_set_rows();
+}
+
 ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
-    return kv->get_k(ctx, il, n_kv);
+    return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
 }
 
 ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
-    return kv->get_v(ctx, il, n_kv);
+    return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
 }
 
 ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
index b8b0356e830c89d84483e6afc474f6cf7492e1f2..3e28e346c3fcf8d1ec09fa15de1b2bd6c4dcb3b4 100644 (file)
@@ -35,16 +35,50 @@ public:
         std::vector<uint32_t> ids;
     };
 
+    struct stream_copy_info {
+        bool empty() const {
+            assert(ssrc.size() == sdst.size());
+            return ssrc.empty();
+        }
+
+        std::vector<uint32_t> ssrc;
+        std::vector<uint32_t> sdst;
+    };
+
     // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
     //   KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
     struct slot_info {
         // data for ggml_set_rows
         using idx_vec_t = std::vector<uint32_t>;
 
-        idx_vec_t idxs;
+        // number of streams: ns = s1 - s0 + 1
+        llama_seq_id s0;
+        llama_seq_id s1;
+
+        std::vector<llama_seq_id> strm; // [ns]
+        std::vector<idx_vec_t>    idxs; // [ns]
 
         uint32_t head() const {
-            return idxs.at(0);
+            GGML_ASSERT(idxs.size() == 1);
+            GGML_ASSERT(!idxs[0].empty());
+
+            return idxs[0][0];
+        }
+
+        void resize(size_t n) {
+            strm.resize(n);
+            idxs.resize(n);
+        }
+
+        size_t size() const {
+            GGML_ASSERT(idxs.size() == strm.size());
+            GGML_ASSERT(!idxs.empty());
+
+            return idxs[0].size();
+        }
+
+        size_t n_stream() const {
+            return strm.size();
         }
 
         bool empty() const {
@@ -54,9 +88,6 @@ public:
         void clear() {
             idxs.clear();
         }
-
-        // TODO: implement
-        //std::vector<idx_vec_t> seq_idxs;
     };
 
     using slot_info_vec_t = std::vector<slot_info>;
@@ -68,6 +99,7 @@ public:
                     ggml_type    type_v,
                          bool    v_trans,
                          bool    offload,
+                         bool    unified,
                      uint32_t    kv_size,
                      uint32_t    n_seq_max,
                      uint32_t    n_pad,
@@ -111,7 +143,8 @@ public:
     // llama_kv_cache_unified specific API
     //
 
-    uint32_t get_size() const;
+    uint32_t get_size()     const;
+    uint32_t get_n_stream() const;
 
     bool get_has_shift() const;
 
@@ -121,9 +154,12 @@ public:
 
     uint32_t get_n_kv() const;
 
+    // TODO: temporary
+    bool get_supports_set_rows() const;
+
     // get views of the current state of the cache
-    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
-    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
@@ -137,7 +173,7 @@ public:
     // return empty vector on failure
     slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
 
-    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
+    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
 
     // find a slot of kv cells that can hold the ubatch
     // if cont == true, then the slot must be continuous
@@ -157,8 +193,9 @@ public:
     void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
     void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
 
+    void set_input_k_shift(ggml_tensor * dst) const;
+
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
-    void set_input_k_shift   (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
 private:
@@ -172,15 +209,15 @@ private:
 
         ggml_tensor * k;
         ggml_tensor * v;
+
+        std::vector<ggml_tensor *> k_stream;
+        std::vector<ggml_tensor *> v_stream;
     };
 
     bool v_trans = true;  // the value tensor is transposed
 
-    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
-    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
-    uint32_t head = 0;
-
     const uint32_t n_seq_max = 1;
+    const uint32_t n_stream  = 1;
 
     // required padding
     const uint32_t n_pad = 1;
@@ -193,14 +230,24 @@ private:
 
     // env: LLAMA_SET_ROWS (temporary)
     // ref: https://github.com/ggml-org/llama.cpp/pull/14285
-    int supports_set_rows = false;
+    bool supports_set_rows = false;
 
     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
     std::vector<ggml_context_ptr>        ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
-    llama_kv_cells_unified cells;
+    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
+    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
+    std::vector<uint32_t> v_heads;
+
+    std::vector<llama_kv_cells_unified> v_cells;
+
+    // maps from a sequence id to a stream id
+    std::vector<uint32_t> seq_to_stream;
+
+    // pending stream copies that will be applied during the next update
+    stream_copy_info sc_info;
 
     std::vector<kv_layer> layers;
 
@@ -226,29 +273,34 @@ private:
                           float   freq_base,
                           float   freq_scale) const;
 
-    llm_graph_result_ptr build_graph_shift(
-            const llama_cparams & cparams,
-                   ggml_context * ctx,
-                    ggml_cgraph * gf) const;
+    ggml_cgraph * build_graph_shift(
+               llm_graph_result * res,
+                  llama_context * lctx) const;
 
-    llm_graph_result_ptr build_graph_defrag(
-            const llama_cparams & cparams,
-                   ggml_context * ctx,
-                    ggml_cgraph * gf,
+    ggml_cgraph * build_graph_defrag(
+               llm_graph_result * res,
+                  llama_context * lctx,
               const defrag_info & dinfo) const;
 
-    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
-    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
+    struct cell_ranges_t {
+        uint32_t strm;
 
-    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
-    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
+        std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
+    };
+
+    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
+
+    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
 };
 
 class llama_kv_cache_unified_context : public llama_memory_context_i {
 public:
     // some shorthands
-    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
-    using defrag_info     = llama_kv_cache_unified::defrag_info;
+    using slot_info_vec_t  = llama_kv_cache_unified::slot_info_vec_t;
+    using defrag_info      = llama_kv_cache_unified::defrag_info;
+    using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
 
     // used for errors
     llama_kv_cache_unified_context(llama_memory_status status);
@@ -262,7 +314,8 @@ public:
             llama_kv_cache_unified * kv,
             llama_context * lctx,
             bool do_shift,
-            defrag_info dinfo);
+            defrag_info dinfo,
+            stream_copy_info sc_info);
 
     // used to create a batch procesing context from a batch
     llama_kv_cache_unified_context(
@@ -288,6 +341,9 @@ public:
 
     uint32_t get_n_kv() const;
 
+    // TODO: temporary
+    bool get_supports_set_rows() const;
+
     // get views of the current state of the cache
     ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
@@ -320,6 +376,8 @@ private:
 
     defrag_info dinfo;
 
+    stream_copy_info sc_info;
+
     //
     // batch processing context
     //
index 6cd10db06b77571675aa2a6dfa6ab8977552d199..d8e2086c87514f34116709cc0f87682457a1f587 100644 (file)
@@ -38,6 +38,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         type_v,
         v_trans,
         offload,
+        1,
         kv_size,
         n_seq_max,
         n_pad,
index 2c1ae67098ca49445ca0305e6b71cda23f09f0ce..c0c2ec084dc1447787e9c3e95aa64f572244eb0c 100644 (file)
@@ -446,7 +446,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
     // A slot should be always be contiguous.
 
     // can only process batches with an equal number of new tokens in each sequence
-    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.equal_seqs());
 
     int32_t min = size - 1;
     int32_t max = 0;
@@ -768,6 +768,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
+        // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+        if (r_l[il] == nullptr) continue;
 
         // Write key type
         const int32_t r_type_i = (int32_t)r_l[il]->type;
@@ -787,6 +789,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
 
     if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+            if (s_l[il] == nullptr) continue;
 
             // Write value type
             const int32_t s_type_i = (int32_t)s_l[il]->type;
@@ -807,6 +811,9 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t mem_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+            if (s_l[il] == nullptr) continue;
+
             const uint32_t n_embd_s = hparams.n_embd_s();
 
             // Write value type
@@ -951,6 +958,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
+        // skip null layers
+        if (r_l[il] == nullptr) continue;
 
         // Read type of key
         int32_t r_type_i_ref;
@@ -978,11 +987,14 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 
     if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers
+            if (s_l[il] == nullptr) continue;
 
             // Read type of value
             int32_t s_type_i_ref;
             io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
             const int32_t s_type_i = (int32_t)s_l[il]->type;
+
             if (s_type_i != s_type_i_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
                 return false;
@@ -1005,6 +1017,9 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
     } else {
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers
+            if (s_l[il] == nullptr) continue;
+
             const uint32_t n_embd_s = hparams.n_embd_s();
 
             // Read type of value
index a322fc39352e7f2a1138f4c053df224399f90e69..71f89e19072ded81e794f7c781ec0f077719475e 100644 (file)
@@ -107,8 +107,10 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_17B_16E:       return "17Bx16E (Scout)";
         case LLM_TYPE_17B_128E:      return "17Bx128E (Maverick)";
         case LLM_TYPE_A13B:          return "A13B";
+        case LLM_TYPE_21B_A3B:       return "21B.A3B";
         case LLM_TYPE_30B_A3B:       return "30B.A3B";
         case LLM_TYPE_235B_A22B:     return "235B.A22B";
+        case LLM_TYPE_300B_A47B:     return "300B.A47B";
         case LLM_TYPE_E2B:           return "E2B";
         case LLM_TYPE_E4B:           return "E4B";
         default:                     return "?B";
@@ -644,6 +646,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_RESIDUAL_SCALE,              hparams.f_residual_scale);
                 ml.get_key(LLM_KV_LOGIT_SCALE,                 hparams.f_logit_scale);
 
+                // MiniCPM uses rope by default, unlike Granite which uses it as a switch
+                hparams.rope_finetuned = true;
+
                 switch (hparams.n_layer) {
                     case 52: type = LLM_TYPE_1B; break;
                     case 40: type = LLM_TYPE_2B; break;
@@ -849,6 +854,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_DREAM:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                // Dream models are primarily 7B with 28 layers
+                switch (hparams.n_layer) {
+                    case 28:
+                        type = LLM_TYPE_7B;
+                        break;
+                    default:
+                        type = LLM_TYPE_UNKNOWN;
+                }
+                // Set non-causal attention for diffusion models
+                hparams.causal_attn = false;
+            }
+            break;
         case LLM_ARCH_QWEN2MOE:
             {
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp, false);
@@ -935,6 +955,33 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                }
             } break;
+        case LLM_ARCH_PLAMO2:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                // Load Mamba SSM parameters
+                ml.get_key(LLM_KV_SSM_CONV_KERNEL,    hparams.ssm_d_conv);
+                ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner);
+                ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state);
+                ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+                ml.get_key(LLM_KV_SSM_GROUP_COUNT,    hparams.ssm_n_group);
+
+                for (uint32_t i = 0; i < hparams.n_layer; ++i) {
+                    hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0;
+                }
+
+                switch (hparams.n_layer) {
+                    case 16: type = LLM_TYPE_1B; break;
+                    case 32:
+                        if (hparams.n_embd == 2048) {
+                            type = LLM_TYPE_2B;
+                        } else if (hparams.n_embd == 4096) {
+                            type = LLM_TYPE_8B;
+                        }
+                        break;
+                    default: type = LLM_TYPE_UNKNOWN;
+               }
+            } break;
         case LLM_ARCH_GPT2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -1322,7 +1369,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     // that have no expert_gating_func model parameter set
                     hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
                 }
-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
 
                 switch (hparams.n_layer) {
                     case 27: type = LLM_TYPE_16B; break;
@@ -1446,6 +1493,23 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_EXAONE4:
+            {
+                if (hparams.n_layer == 64) {    // 32B
+                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                    hparams.n_swa = 4096;
+                    hparams.set_swa_pattern(4);
+                }
+
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa, false);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 30: type = LLM_TYPE_1_2B; break;
+                    case 64: type = LLM_TYPE_32B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_RWKV6:
         case LLM_ARCH_RWKV6QWEN2:
             {
@@ -1483,7 +1547,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT,                      hparams.token_shift_count, false);
 
                 switch (hparams.n_layer) {
-                    case 12: type = LLM_TYPE_190M; break;
+                    case 12:
+                        switch (hparams.n_embd) {
+                            case 768: type = LLM_TYPE_190M; break;
+                            default: type = LLM_TYPE_UNKNOWN;
+                        } break;
                     case 24:
                         switch (hparams.n_embd) {
                             case 1024: type = LLM_TYPE_450M; break;
@@ -1496,7 +1564,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                             case 3584: type = LLM_TYPE_7B; break;
                             default: type = LLM_TYPE_UNKNOWN;
                         } break;
-                    case 32: type = LLM_TYPE_2_9B; break; // RWKV-7-World
+                    case 32:
+                        switch (hparams.n_embd) {
+                            case 2560: type = LLM_TYPE_2_9B; break;
+                            case 4096: type = LLM_TYPE_7B; break;
+                            default: type = LLM_TYPE_UNKNOWN;
+                        } break;
+                    case 61:
+                        switch (hparams.n_embd) {
+                            case 4096: type = LLM_TYPE_14B; break;
+                            default: type = LLM_TYPE_UNKNOWN;
+                        } break;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
@@ -1607,10 +1685,20 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 }
             } break;
         case LLM_ARCH_ERNIE4_5:
+        case LLM_ARCH_ERNIE4_5_MOE:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                if (arch == LLM_ARCH_ERNIE4_5_MOE) {
+                    ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp);
+                    ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
+                    ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP,         hparams.n_moe_layer_step);
+                    ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT,         hparams.n_layer_dense_lead);
+                }
+
                 switch (hparams.n_layer) {
                     case 18: type = LLM_TYPE_0_3B; break;
+                    case 28: type = LLM_TYPE_21B_A3B; break;
+                    case 54: type = LLM_TYPE_300B_A47B; break;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
@@ -2643,12 +2731,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 } break;
             case LLM_ARCH_QWEN2:
             case LLM_ARCH_QWEN2VL:
+            case LLM_ARCH_DREAM:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
                     output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    output_b    = create_tensor(tn(LLM_TENSOR_OUTPUT,      "bias"),   {n_vocab}, TENSOR_NOT_REQUIRED);
                     // if output is NULL, init from the input tok embed
                     if (output == NULL) {
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
@@ -2938,6 +3028,73 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
+            case LLM_ARCH_PLAMO2:
+                {
+                    const uint32_t d_conv             = hparams.ssm_d_conv;
+                    const uint32_t d_state            = hparams.ssm_d_state;
+                    const uint32_t num_heads          = hparams.ssm_dt_rank;
+                    const uint32_t intermediate_size  = hparams.ssm_d_inner;
+                    const uint32_t head_dim           = intermediate_size / num_heads;
+                    const uint32_t qk_dim             = head_dim;
+                    const uint32_t v_dim              = head_dim;
+                    const int64_t num_attention_heads = hparams.n_head();
+                    const int64_t q_num_heads         = num_attention_heads;
+                    const int64_t dt_dim              = std::max(64, int(hparams.n_embd / 16));
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+                        bool is_mamba_layer = hparams.is_recurrent(i);
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (is_mamba_layer) {
+                            layer.ssm_in       = create_tensor(tn(LLM_TENSOR_SSM_IN,     "weight", i), {n_embd, 2 * intermediate_size}, 0);
+                            layer.ssm_conv1d   = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0);
+
+                            layer.ssm_x    = create_tensor(tn(LLM_TENSOR_SSM_X,  "weight", i), {intermediate_size, dt_dim + 2*d_state}, 0);
+                            layer.ssm_dt   = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_dim, num_heads}, 0);
+                            layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0);
+
+                            layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0);
+                            layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0);
+
+                            layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0);
+
+                            layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0);
+                            layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0);
+                            layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0);
+                        } else {
+                            const int64_t num_key_value_heads = hparams.n_head_kv(i);
+                            const int64_t k_num_heads         = num_key_value_heads;
+                            const int64_t v_num_heads         = num_key_value_heads;
+                            const int64_t q_proj_dim          = q_num_heads * qk_dim;
+                            const int64_t k_proj_dim          = k_num_heads * qk_dim;
+                            const int64_t v_proj_dim          = v_num_heads * v_dim;
+
+                            layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0);
+                            layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0);
+                            layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0);
+                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0);
+                        }
+
+                        // All layers have post-attention norm, FFN norm, and FFN tensors
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0);
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff * 2}, 0);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0);
+                    }
+                } break;
             case LLM_ARCH_GPT2:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4232,6 +4389,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,     "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
+            case LLM_ARCH_EXAONE4:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_post_norm  = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+                    }
+                } break;
             case LLM_ARCH_RWKV6:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4747,6 +4937,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
                 } break;
             case LLM_ARCH_ERNIE4_5:
+            case LLM_ARCH_ERNIE4_5_MOE:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
@@ -4775,9 +4966,27 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
 
                         layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+
+                        if (arch == LLM_ARCH_ERNIE4_5_MOE && static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) { // MoE layers
+                            int n_ff_exp = hparams.n_ff_exp;
+
+                            layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff_exp, n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff_exp, n_expert}, 0);
+
+                            // Shared expert (if present)
+                            if (hparams.n_ff_shexp > 0) {
+                                layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {    n_embd, hparams.n_ff_shexp}, 0);
+                                layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd    }, 0);
+                                layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {    n_embd, hparams.n_ff_shexp}, 0);
+                            }
+                        } else { // Dense layers
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        }
                     }
                 } break;
             case LLM_ARCH_FALCON_H1:
@@ -5209,6 +5418,7 @@ void llama_model::print_info() const {
         arch == LLM_ARCH_MAMBA2 ||
         arch == LLM_ARCH_JAMBA ||
         arch == LLM_ARCH_FALCON_H1 ||
+        arch == LLM_ARCH_PLAMO2 ||
         arch == LLM_ARCH_GRANITE_HYBRID) {
         LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
         LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
@@ -5381,7 +5591,7 @@ ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int i
 }
 
 struct llm_build_llama : public llm_graph_context {
-    llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -5457,7 +5667,7 @@ struct llm_build_llama : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
@@ -5537,7 +5747,7 @@ struct llm_build_llama : public llm_graph_context {
 };
 
 struct llm_build_llama_iswa : public llm_graph_context {
-    llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -5631,7 +5841,7 @@ struct llm_build_llama_iswa : public llm_graph_context {
                     cb(Kcur, "Kcur_normed", il);
                 }
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
@@ -5720,7 +5930,7 @@ struct llm_build_llama_iswa : public llm_graph_context {
 };
 
 struct llm_build_deci : public llm_graph_context {
-    llm_build_deci(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -5808,7 +6018,7 @@ struct llm_build_deci : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
             }
@@ -5876,7 +6086,7 @@ struct llm_build_deci : public llm_graph_context {
 };
 
 struct llm_build_baichuan : public llm_graph_context {
-    llm_build_baichuan(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -5940,7 +6150,7 @@ struct llm_build_baichuan : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -5998,7 +6208,7 @@ struct llm_build_baichuan : public llm_graph_context {
 };
 
 struct llm_build_xverse : public llm_graph_context {
-    llm_build_xverse(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_xverse(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -6055,7 +6265,7 @@ struct llm_build_xverse : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -6111,7 +6321,7 @@ struct llm_build_xverse : public llm_graph_context {
 };
 
 struct llm_build_falcon : public llm_graph_context {
-    llm_build_falcon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -6178,7 +6388,7 @@ struct llm_build_falcon : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -6233,7 +6443,7 @@ struct llm_build_falcon : public llm_graph_context {
 };
 
 struct llm_build_grok : public llm_graph_context {
-    llm_build_grok(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_grok(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -6308,7 +6518,7 @@ struct llm_build_grok : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }
@@ -6395,7 +6605,7 @@ struct llm_build_grok : public llm_graph_context {
 };
 
 struct llm_build_dbrx : public llm_graph_context {
-    llm_build_dbrx(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -6457,7 +6667,7 @@ struct llm_build_dbrx : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -6520,7 +6730,7 @@ struct llm_build_dbrx : public llm_graph_context {
 };
 
 struct llm_build_starcoder : public llm_graph_context {
-    llm_build_starcoder(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_starcoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -6571,7 +6781,7 @@ struct llm_build_starcoder : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -6629,7 +6839,7 @@ struct llm_build_starcoder : public llm_graph_context {
 };
 
 struct llm_build_refact : public llm_graph_context {
-    llm_build_refact(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_refact(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -6670,7 +6880,7 @@ struct llm_build_refact : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -6728,7 +6938,7 @@ struct llm_build_refact : public llm_graph_context {
 };
 
 struct llm_build_bert : public llm_graph_context {
-    llm_build_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -6827,7 +7037,7 @@ struct llm_build_bert : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
                 cb(cur, "kqv_out", il);
@@ -6914,7 +7124,7 @@ struct llm_build_bert : public llm_graph_context {
 };
 
 struct llm_build_neo_bert : public llm_graph_context {
-    llm_build_neo_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_neo_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -6972,7 +7182,7 @@ struct llm_build_neo_bert : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, nullptr,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
                 cb(cur, "kqv_out", il);
@@ -7024,7 +7234,7 @@ struct llm_build_neo_bert : public llm_graph_context {
 };
 
 struct llm_build_bloom : public llm_graph_context {
-    llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -7072,7 +7282,7 @@ struct llm_build_bloom : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -7130,7 +7340,7 @@ struct llm_build_bloom : public llm_graph_context {
 };
 
 struct llm_build_mpt : public llm_graph_context {
-    llm_build_mpt(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_mpt(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -7219,7 +7429,7 @@ struct llm_build_mpt : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -7278,7 +7488,7 @@ struct llm_build_mpt : public llm_graph_context {
 };
 
 struct llm_build_stablelm : public llm_graph_context {
-    llm_build_stablelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_stablelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7365,7 +7575,7 @@ struct llm_build_stablelm : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -7430,7 +7640,7 @@ struct llm_build_stablelm : public llm_graph_context {
 };
 
 struct llm_build_qwen : public llm_graph_context {
-    llm_build_qwen(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_qwen(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7486,7 +7696,7 @@ struct llm_build_qwen : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -7544,7 +7754,7 @@ struct llm_build_qwen : public llm_graph_context {
 };
 
 struct llm_build_qwen2 : public llm_graph_context {
-    llm_build_qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_qwen2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7606,7 +7816,7 @@ struct llm_build_qwen2 : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -7654,6 +7864,113 @@ struct llm_build_qwen2 : public llm_graph_context {
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
+        if (model.output_b != nullptr) {
+            cur = ggml_add(ctx0, cur, model.output_b);
+        }
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_dream : public llm_graph_context {
+    llm_build_dream(const llama_model & model, const llm_graph_params & params) :
+        llm_graph_context(params) {
+        //copied from qwen2
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_no_cache();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                Qcur               = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                Kcur               = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                Vcur               = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                     ext_factor, attn_factor, beta_fast, beta_slow);
+
+                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                     ext_factor, attn_factor, beta_fast, beta_slow);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr,
+                                 nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
+                            model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
         cb(cur, "result_output", -1);
         res->t_logits = cur;
 
@@ -7662,7 +7979,7 @@ struct llm_build_qwen2 : public llm_graph_context {
 };
 
 struct llm_build_qwen2vl : public llm_graph_context {
-    llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7727,7 +8044,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -7783,7 +8100,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
 };
 
 struct llm_build_qwen2moe : public llm_graph_context {
-    llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7854,7 +8171,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -7942,7 +8259,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
 };
 
 struct llm_build_qwen3 : public llm_graph_context {
-    llm_build_qwen3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_qwen3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8007,7 +8324,7 @@ struct llm_build_qwen3 : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -8063,7 +8380,7 @@ struct llm_build_qwen3 : public llm_graph_context {
 };
 
 struct llm_build_qwen3moe : public llm_graph_context {
-    llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8128,7 +8445,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -8191,7 +8508,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
 };
 
 struct llm_build_phi2 : public llm_graph_context {
-    llm_build_phi2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_phi2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -8268,7 +8585,7 @@ struct llm_build_phi2 : public llm_graph_context {
                 // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
                 Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head)));
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }
@@ -8322,7 +8639,7 @@ struct llm_build_phi2 : public llm_graph_context {
 
 template<bool iswa>
 struct llm_build_phi3 : public llm_graph_context {
-    llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_phi3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
 
@@ -8405,7 +8722,7 @@ struct llm_build_phi3 : public llm_graph_context {
                 Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
                 cb(Qcur, "Qcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }
@@ -8480,7 +8797,7 @@ struct llm_build_phi3 : public llm_graph_context {
 };
 
 struct llm_build_plamo : public llm_graph_context {
-    llm_build_plamo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_plamo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8539,7 +8856,7 @@ struct llm_build_plamo : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -8595,7 +8912,7 @@ struct llm_build_plamo : public llm_graph_context {
 };
 
 struct llm_build_gpt2 : public llm_graph_context {
-    llm_build_gpt2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_gpt2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -8647,7 +8964,7 @@ struct llm_build_gpt2 : public llm_graph_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                 Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -8705,7 +9022,7 @@ struct llm_build_gpt2 : public llm_graph_context {
 };
 
 struct llm_build_codeshell : public llm_graph_context {
-    llm_build_codeshell(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -8761,7 +9078,7 @@ struct llm_build_codeshell : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -8819,7 +9136,7 @@ struct llm_build_codeshell : public llm_graph_context {
 };
 
 struct llm_build_orion : public llm_graph_context {
-    llm_build_orion(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_orion(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8890,7 +9207,7 @@ struct llm_build_orion : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -8946,7 +9263,7 @@ struct llm_build_orion : public llm_graph_context {
 };
 
 struct llm_build_internlm2 : public llm_graph_context {
-    llm_build_internlm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_internlm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9017,7 +9334,7 @@ struct llm_build_internlm2 : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -9073,7 +9390,7 @@ struct llm_build_internlm2 : public llm_graph_context {
 };
 
 struct llm_build_minicpm3 : public llm_graph_context {
-    llm_build_minicpm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_minicpm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         //TODO: if the model varies, these parameters need to be read from the model
         const int64_t n_embd_base = 256;
         const float scale_embd  = 12.0f;
@@ -9205,7 +9522,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
                 ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
                 cb(k_states, "k_states", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
             }
@@ -9277,7 +9594,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
 };
 
 struct llm_build_gemma : public llm_graph_context {
-    llm_build_gemma(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         ggml_tensor * cur;
@@ -9335,7 +9652,7 @@ struct llm_build_gemma : public llm_graph_context {
                 Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
                 cb(Qcur, "Qcur_scaled", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }
@@ -9393,7 +9710,7 @@ struct llm_build_gemma : public llm_graph_context {
 };
 
 struct llm_build_gemma2_iswa : public llm_graph_context {
-    llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_k;
 
         ggml_tensor * cur;
@@ -9450,7 +9767,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
 
                 Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }
@@ -9523,7 +9840,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
 };
 
 struct llm_build_gemma3_iswa : public llm_graph_context {
-    llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_k;
 
         ggml_tensor * cur;
@@ -9592,7 +9909,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
                 // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
                 Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }
@@ -9661,7 +9978,6 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
 
 struct llm_build_gemma3n_iswa : public llm_graph_context {
     const llama_model & model;
-    ggml_cgraph * gf;
 
     const int64_t n_embd_head;
     const int64_t n_embd_altup;
@@ -9671,10 +9987,9 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
     const int     n_layer_sparsity = 10; // number of layers using activation sparsity
     const float   f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
 
-    llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
+    llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params)
             : llm_graph_context(params),
               model(model),
-              gf(gf),
               n_embd_head(model.hparams.n_embd_head_k),
               n_embd_altup(model.hparams.n_embd_altup),
               n_altup(model.hparams.n_altup),
@@ -9775,7 +10090,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
                 cb(Qcur, "Qcur_pos", il);
                 cb(Kcur, "Kcur_pos", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
             } else {
@@ -9793,7 +10108,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur_pos", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                     model.layers[il].wo, NULL,
                     Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
             }
@@ -10087,7 +10402,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
 
 // TODO: move up next to build_starcoder
 struct llm_build_starcoder2 : public llm_graph_context {
-    llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10158,7 +10473,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -10219,7 +10534,6 @@ struct llm_graph_context_mamba : public llm_graph_context {
 
     ggml_tensor * build_mamba_layer(
         llm_graph_input_rs * inp,
-               ggml_cgraph * gf,
                ggml_tensor * cur,
          const llama_model & model,
         const llama_ubatch & ubatch,
@@ -10244,13 +10558,13 @@ struct llm_graph_context_mamba : public llm_graph_context {
         const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
         GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs());
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
         ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
         ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
 
-        ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
+        ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
@@ -10331,7 +10645,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
                 return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
             };
 
-            ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
+            ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
 
             // store last states
             ggml_build_forward_expand(gf,
@@ -10358,11 +10672,10 @@ struct llm_graph_context_mamba : public llm_graph_context {
 
     ggml_tensor * build_mamba2_layer(
         llm_graph_input_rs * inp,
-             ggml_cgraph * gf,
-             ggml_tensor * cur,
-       const llama_model & model,
-      const llama_ubatch & ubatch,
-                     int   il) const {
+               ggml_tensor * cur,
+         const llama_model & model,
+        const llama_ubatch & ubatch,
+                       int   il) const {
 
         const auto * mctx_cur = inp->mctx;
 
@@ -10379,13 +10692,13 @@ struct llm_graph_context_mamba : public llm_graph_context {
         const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
         GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs());
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
         ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
         ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
 
-        ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
+        ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
@@ -10455,7 +10768,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
                 return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
             };
 
-            ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
+            ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
 
             // store last states
             ggml_build_forward_expand(gf,
@@ -10491,7 +10804,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
 };
 
 struct llm_build_mamba : public llm_graph_context_mamba {
-    llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
+    llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
         ggml_tensor * cur;
         ggml_tensor * inpL;
 
@@ -10510,9 +10823,9 @@ struct llm_build_mamba : public llm_graph_context_mamba {
             cb(cur, "attn_norm", il);
 
             if (model.arch == LLM_ARCH_MAMBA2) {
-                cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il);
+                cur = build_mamba2_layer(rs_inp, cur, model, ubatch, il);
             } else {
-                cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il);
+                cur = build_mamba_layer(rs_inp, cur, model, ubatch, il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -10548,7 +10861,7 @@ struct llm_build_mamba : public llm_graph_context_mamba {
 };
 
 struct llm_build_jamba : public llm_graph_context_mamba {
-    llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
+    llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         ggml_tensor * cur;
@@ -10568,7 +10881,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
             cb(cur, "attn_norm", il);
 
             if (n_head_kv == 0) {
-                cur = build_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il);
+                cur = build_mamba_layer(inp_hybrid->get_recr(), cur, model, ubatch, il);
             } else {
                 // Attention
 
@@ -10589,7 +10902,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
                 cb(Vcur, "Vcur", il);
 
                 // No RoPE :)
-                cur = build_attn(inp_hybrid->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
+                cur = build_attn(inp_hybrid->get_attn(), model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -10657,7 +10970,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
 };
 
 struct llm_build_command_r : public llm_graph_context {
-    llm_build_command_r(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_command_r(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10745,7 +11058,7 @@ struct llm_build_command_r : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -10804,7 +11117,7 @@ struct llm_build_command_r : public llm_graph_context {
 };
 
 struct llm_build_cohere2_iswa : public llm_graph_context {
-    llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10880,7 +11193,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -10940,7 +11253,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
 //   * removed bias
 //   * removed MoE
 struct llm_build_olmo : public llm_graph_context {
-    llm_build_olmo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_olmo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11011,7 +11324,7 @@ struct llm_build_olmo : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, nullptr,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -11068,7 +11381,7 @@ struct llm_build_olmo : public llm_graph_context {
 };
 
 struct llm_build_olmo2 : public llm_graph_context {
-    llm_build_olmo2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_olmo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11131,7 +11444,7 @@ struct llm_build_olmo2 : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -11197,7 +11510,7 @@ struct llm_build_olmo2 : public llm_graph_context {
 //   * removed bias
 //   * added q, k norm
 struct llm_build_olmoe : public llm_graph_context {
-    llm_build_olmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_olmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11264,7 +11577,7 @@ struct llm_build_olmoe : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -11325,7 +11638,7 @@ struct llm_build_olmoe : public llm_graph_context {
 };
 
 struct llm_build_openelm : public llm_graph_context {
-    llm_build_openelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_openelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11397,7 +11710,7 @@ struct llm_build_openelm : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Qcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -11454,7 +11767,7 @@ struct llm_build_openelm : public llm_graph_context {
 };
 
 struct llm_build_gptneox : public llm_graph_context {
-    llm_build_gptneox(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_gptneox(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -11509,7 +11822,7 @@ struct llm_build_gptneox : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -11600,7 +11913,7 @@ struct llm_build_gptneox : public llm_graph_context {
 };
 
 struct llm_build_arctic : public llm_graph_context {
-    llm_build_arctic(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_arctic(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11659,7 +11972,7 @@ struct llm_build_arctic : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -11738,7 +12051,7 @@ struct llm_build_arctic : public llm_graph_context {
 };
 
 struct llm_build_deepseek : public llm_graph_context {
-    llm_build_deepseek(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_deepseek(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11814,7 +12127,7 @@ struct llm_build_deepseek : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
             }
@@ -11900,7 +12213,7 @@ struct llm_build_deepseek : public llm_graph_context {
 };
 
 struct llm_build_deepseek2 : public llm_graph_context {
-    llm_build_deepseek2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         bool is_lite = (hparams.n_layer == 27);
 
         const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
@@ -12042,7 +12355,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
                     cb(Vcur, "Vcur", il);
 
                     // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group)
-                    cur = build_attn(inp_attn, gf,
+                    cur = build_attn(inp_attn,
                             model.layers[il].wo, NULL,
                             Qcur, Kcur, Vcur, nullptr, model.layers[il].wv_b, kq_scale, il);
                 } else {
@@ -12076,7 +12389,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
                     cb(Kcur, "Kcur", il);
 
                     // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
-                    cur = build_attn(inp_attn, gf,
+                    cur = build_attn(inp_attn,
                             model.layers[il].wo, NULL,
                             Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
                 }
@@ -12163,7 +12476,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
 };
 
 struct llm_build_bitnet : public llm_graph_context {
-    llm_build_bitnet(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_bitnet(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -12243,7 +12556,7 @@ struct llm_build_bitnet : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         NULL, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
 
@@ -12323,7 +12636,7 @@ struct llm_build_bitnet : public llm_graph_context {
 };
 
 struct llm_build_t5_enc : public llm_graph_context {
-    llm_build_t5_enc(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_t5_enc(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -12366,7 +12679,7 @@ struct llm_build_t5_enc : public llm_graph_context {
                 ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
                 ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo_enc, nullptr,
                         Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il);
                 cb(cur, "kqv_out", il);
@@ -12424,7 +12737,7 @@ struct llm_build_t5_enc : public llm_graph_context {
 };
 
 struct llm_build_t5_dec : public llm_graph_context {
-    llm_build_t5_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_t5_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         //const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -12472,7 +12785,7 @@ struct llm_build_t5_dec : public llm_graph_context {
                 ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
                 ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b);
 
-                cur = build_attn(inp_attn_self, gf,
+                cur = build_attn(inp_attn_self,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il);
                 cb(cur, "kqv_out", il);
@@ -12504,7 +12817,7 @@ struct llm_build_t5_dec : public llm_graph_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
                 Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc);
 
-                cur = build_attn(inp_attn_cross, gf,
+                cur = build_attn(inp_attn_cross,
                         model.layers[il].wo_cross, nullptr,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
                 cb(cur, "kqv_out", il);
@@ -12594,7 +12907,7 @@ struct llm_build_t5_dec : public llm_graph_context {
 };
 
 struct llm_build_jais : public llm_graph_context {
-    llm_build_jais(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_jais(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -12636,7 +12949,7 @@ struct llm_build_jais : public llm_graph_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                 Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il);
             }
@@ -12689,7 +13002,7 @@ struct llm_build_jais : public llm_graph_context {
 };
 
 struct llm_build_chatglm : public llm_graph_context {
-    llm_build_chatglm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -12768,7 +13081,7 @@ struct llm_build_chatglm : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -12822,7 +13135,7 @@ struct llm_build_chatglm : public llm_graph_context {
 };
 
 struct llm_build_glm4 : public llm_graph_context {
-    llm_build_glm4(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -12901,7 +13214,7 @@ struct llm_build_glm4 : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -12973,7 +13286,7 @@ struct llm_build_glm4 : public llm_graph_context {
 };
 
 struct llm_build_nemotron : public llm_graph_context {
-    llm_build_nemotron(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13045,7 +13358,7 @@ struct llm_build_nemotron : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -13102,7 +13415,7 @@ struct llm_build_nemotron : public llm_graph_context {
 };
 
 struct llm_build_exaone : public llm_graph_context {
-    llm_build_exaone(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13176,7 +13489,7 @@ struct llm_build_exaone : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -13232,32 +13545,168 @@ struct llm_build_exaone : public llm_graph_context {
     }
 };
 
-struct llm_build_rwkv6_base : public llm_graph_context {
-    const llama_model & model;
+template <bool iswa>
+struct llm_build_exaone4 : public llm_graph_context {
+    llm_build_exaone4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_k;
 
-    llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {
-    }
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
 
-    ggml_tensor * build_rwkv6_channel_mix(
-            const llama_layer * layer,
-            ggml_tensor * cur,
-            ggml_tensor * x_prev,
-            llm_arch arch) const {
-        ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
-        switch (arch) {
-            case LLM_ARCH_RWKV6:
-                {
-                    ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
-                    ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur);
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
 
-                    ggml_tensor * r = ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr));
-                    ggml_tensor * k = ggml_sqr(
-                            ctx0,
-                            ggml_relu(
-                                ctx0,
-                                build_lora_mm(layer->channel_mix_key, xk)
-                                )
-                            );
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        inp_attn_type * inp_attn = nullptr;
+
+        if constexpr (iswa) {
+            inp_attn = build_attn_inp_kv_unified_iswa();
+        } else {
+            inp_attn = build_attn_inp_kv_unified();
+        }
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // use RoPE for SWA layers or non-SWA models
+            const bool use_rope = hparams.is_swa(il) || hparams.swa_type == LLAMA_SWA_TYPE_NONE;
+
+            cur = inpL;
+
+            // self-attention
+            {
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+                cb(Kcur, "Kcur_normed", il);
+
+                if (use_rope) {
+                    Qcur = ggml_rope_ext(
+                            ctx0, Qcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
+
+                    Kcur = ggml_rope_ext(
+                            ctx0, Kcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_ffn(ffn_inp,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_rwkv6_base : public llm_graph_context {
+    const llama_model & model;
+
+    llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {
+    }
+
+    ggml_tensor * build_rwkv6_channel_mix(
+            const llama_layer * layer,
+            ggml_tensor * cur,
+            ggml_tensor * x_prev,
+            llm_arch arch) const {
+        ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
+        switch (arch) {
+            case LLM_ARCH_RWKV6:
+                {
+                    ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
+                    ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur);
+
+                    ggml_tensor * r = ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr));
+                    ggml_tensor * k = ggml_sqr(
+                            ctx0,
+                            ggml_relu(
+                                ctx0,
+                                build_lora_mm(layer->channel_mix_key, xk)
+                                )
+                            );
                     cur = ggml_mul(ctx0, r, build_lora_mm(layer->channel_mix_value, k));
                 } break;
             default:
@@ -13269,7 +13718,6 @@ struct llm_build_rwkv6_base : public llm_graph_context {
 
     ggml_tensor * build_rwkv6_time_mix(
             llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * x_prev,
             const llama_ubatch & ubatch,
@@ -13396,7 +13844,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
         }
 
         ggml_tensor * wkv_state = build_rs(
-                inp, gf, mctx_cur->get_s_l(il),
+                inp, mctx_cur->get_s_l(il),
                 hparams.n_embd_s(), n_seqs);
 
         ggml_tensor * wkv_output;
@@ -13442,7 +13890,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
 };
 
 struct llm_build_rwkv6 : public llm_build_rwkv6_base {
-    llm_build_rwkv6(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
+    llm_build_rwkv6(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) {
         GGML_ASSERT(hparams.token_shift_count == 2);
 
         ggml_tensor * cur;
@@ -13463,7 +13911,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
             const llama_layer * layer = &model.layers[il];
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
-            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il);
 
             ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
             ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@@ -13478,7 +13926,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
                     1
                     );
 
-            cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
+            cur = build_rwkv6_time_mix(rs_inp, att_norm, x_prev, ubatch, il);
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
             cb(ffn_inp, "ffn_inp", il);
@@ -13543,7 +13991,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
 
 // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
 struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
-    llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
+    llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) {
         GGML_ASSERT(n_embd == hparams.n_embd_r());
 
         ggml_tensor * cur;
@@ -13563,7 +14011,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
             const llama_layer * layer = &model.layers[il];
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
-            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il);
 
             ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
             cb(att_norm, "attn_norm", il);
@@ -13575,7 +14023,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
                     1
                     );
 
-            cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
+            cur = build_rwkv6_time_mix(rs_inp, att_norm, x_prev, ubatch, il);
 
             token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
             ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -13665,7 +14113,6 @@ struct llm_build_rwkv7_base : public llm_graph_context {
 
     ggml_tensor * build_rwkv7_time_mix(
             llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * x_prev,
             ggml_tensor *& first_layer_value,
@@ -13751,7 +14198,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
 
         ggml_tensor * wkv_state = build_rs(
-                inp, gf, mctx_cur->get_s_l(il),
+                inp, mctx_cur->get_s_l(il),
                 hparams.n_embd_s(), n_seqs);
 
         ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@@ -13798,7 +14245,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
 };
 
 struct llm_build_rwkv7 : public llm_build_rwkv7_base {
-    llm_build_rwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
+    llm_build_rwkv7(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) {
         GGML_ASSERT(hparams.token_shift_count == 2);
 
         ggml_tensor * cur;
@@ -13820,7 +14267,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
             const llama_layer * layer = &model.layers[il];
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
-            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il);
 
             ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
             ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@@ -13835,7 +14282,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
                     1
                     );
 
-            cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
+            cur = build_rwkv7_time_mix(rs_inp, att_norm, x_prev, v_first, ubatch, il);
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
             cb(ffn_inp, "ffn_inp", il);
@@ -13894,7 +14341,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
 
 
 struct llm_build_arwkv7 : public llm_build_rwkv7_base {
-    llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
+    llm_build_arwkv7(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) {
         GGML_ASSERT(n_embd == hparams.n_embd_r());
 
         ggml_tensor * cur;
@@ -13915,7 +14362,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
             const llama_layer * layer = &model.layers[il];
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
-            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il);
 
             ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
             cb(att_norm, "attn_norm", il);
@@ -13927,7 +14374,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
                     1
                     );
 
-            cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
+            cur = build_rwkv7_time_mix(rs_inp, att_norm, x_prev, v_first, ubatch, il);
 
             token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
             ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -13984,8 +14431,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
 struct llm_build_granite : public llm_graph_context {
     llm_build_granite(
         const llama_model & model,
-        const llm_graph_params & params,
-        ggml_cgraph * gf)
+        const llm_graph_params & params)
         : llm_graph_context(params) {
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -14019,7 +14465,7 @@ struct llm_build_granite : public llm_graph_context {
 
             // self-attention
             cur = build_attention_layer(
-                gf, cur, inp_pos, inp_attn,
+                cur, inp_pos, inp_attn,
                 model, n_embd_head, il);
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -14055,7 +14501,6 @@ struct llm_build_granite : public llm_graph_context {
     }
 
     ggml_tensor * build_attention_layer(
-              ggml_cgraph                     * gf,
               ggml_tensor                     * cur,
               ggml_tensor                     * inp_pos,
               llm_graph_input_attn_kv_unified * inp_attn,
@@ -14110,7 +14555,7 @@ struct llm_build_granite : public llm_graph_context {
         cb(Vcur, "Vcur", il);
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
-        cur = build_attn(inp_attn, gf,
+        cur = build_attn(inp_attn,
                 model.layers[il].wo, model.layers[il].bo,
                 Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
@@ -14198,11 +14643,9 @@ struct llm_build_granite : public llm_graph_context {
 };
 
 struct llm_build_granite_hybrid : public llm_graph_context_mamba {
-
     llm_build_granite_hybrid(
                  const llama_model & model,
-            const llm_graph_params & params,
-                       ggml_cgraph * gf) :
+            const llm_graph_params & params) :
         llm_graph_context_mamba(params) {
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -14234,11 +14677,11 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
 
             if (hparams.is_recurrent(il)) {
                 // ssm layer //
-                cur = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il);
+                cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il);
             } else {
                 // attention layer //
                 cur = build_attention_layer(
-                    gf, cur, inp_pos, inp->get_attn(), model,
+                    cur, inp_pos, inp->get_attn(), model,
                     n_embd_head, il);
             }
 
@@ -14277,7 +14720,6 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
     }
 
     ggml_tensor * build_attention_layer(
-              ggml_cgraph                     * gf,
               ggml_tensor                     * cur,
               ggml_tensor                     * inp_pos,
               llm_graph_input_attn_kv_unified * inp_attn,
@@ -14332,7 +14774,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
         cb(Vcur, "Vcur", il);
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
-        cur = build_attn(inp_attn, gf,
+        cur = build_attn(inp_attn,
                 model.layers[il].wo, model.layers[il].bo,
                 Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
@@ -14426,7 +14868,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
 //   * removed bias
 //   * removed MoE
 struct llm_build_chameleon : public llm_graph_context {
-    llm_build_chameleon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_chameleon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -14517,7 +14959,7 @@ struct llm_build_chameleon : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, nullptr,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -14603,7 +15045,7 @@ struct llm_build_chameleon : public llm_graph_context {
 };
 
 struct llm_build_wavtokenizer_dec : public llm_graph_context {
-    llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         ggml_tensor * cur;
         ggml_tensor * inpL;
 
@@ -14755,7 +15197,7 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context {
 };
 
 struct llm_build_plm : public llm_graph_context {
-    llm_build_plm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_plm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k));
 
         const uint32_t n_embd_head_qk_rope = hparams.n_rot;
@@ -14873,7 +15315,7 @@ struct llm_build_plm : public llm_graph_context {
                 ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
                 cb(k_states, "k_states", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
             }
@@ -14927,7 +15369,7 @@ struct llm_build_plm : public llm_graph_context {
 };
 
 struct llm_build_bailingmoe : public llm_graph_context {
-    llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         ggml_tensor * cur;
         ggml_tensor * inpL;
 
@@ -14996,7 +15438,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
             }
@@ -15071,7 +15513,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
 };
 
 struct llm_build_dots1 : public llm_graph_context {
-    llm_build_dots1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_dots1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -15136,7 +15578,7 @@ struct llm_build_dots1 : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -15221,7 +15663,7 @@ struct llm_build_dots1 : public llm_graph_context {
 };
 
 struct llm_build_ernie4_5 : public llm_graph_context {
-    llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -15291,7 +15733,7 @@ struct llm_build_ernie4_5 : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
@@ -15350,8 +15792,178 @@ struct llm_build_ernie4_5 : public llm_graph_context {
     }
 };
 
+struct llm_build_ernie4_5_moe : public llm_graph_context {
+    llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Ernie 4.5 MoE requires n_moe_layer_step > 0");
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+            // norm
+            {
+                cur = build_norm(inpL,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "attn_norm", il);
+            }
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            bool is_moe_layer = static_cast<uint32_t>(il) >= hparams.n_layer_dense_lead && (il + 1) % hparams.n_moe_layer_step == 0;
+
+            if (!is_moe_layer) {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                ggml_tensor * moe_out = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        model.layers[il].ffn_exp_probs_b,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                // Shared expert (if present)
+                if (hparams.n_ff_shexp > 0) {
+                    ggml_tensor * ffn_shexp = build_ffn(cur,
+                        model.layers[il].ffn_up_shexp,   NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                } else {
+                    cur = moe_out;
+                }
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_falcon_h1 : public llm_graph_context_mamba {
-    llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
+    llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         ggml_tensor * cur;
@@ -15407,7 +16019,7 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba {
             cb(Kcur, "Kcur-post-rope", il);
             cb(Vcur, "Vcur-post-rope", il);
 
-            ggml_tensor * attn_out = build_attn(inp->get_attn(), gf,
+            ggml_tensor * attn_out = build_attn(inp->get_attn(),
                     model.layers[il].wo, NULL,
                     Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
             cb(attn_out, "attn_out", il);
@@ -15418,7 +16030,7 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba {
             // Mamba2 layer
             cb(cur, "ssm_in", il);
 
-            ggml_tensor * ssm_out = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il);
+            ggml_tensor * ssm_out = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il);
             cb(ssm_out, "ssm_out", il);
 
             // // Aggregation
@@ -15476,8 +16088,321 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba {
     }
 };
 
+struct llm_build_plamo2 : public llm_graph_context_mamba {
+    llm_build_plamo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        // {n_embd, n_tokens}
+        inpL = build_inp_embd(model.tok_embd);
+        cb(inpL, "embedding_output", -1);
+
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_hybrid = build_inp_mem_hybrid();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * residual = inpL;
+
+            // ggml_graph_add_node(gf, model.layers[il].attn_norm);
+            // cb(model.layers[il].attn_norm, "attn_norm", il);
+
+            // pre_mixer_norm
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+
+            // check if this layer is Mamba or Attention
+            bool is_mamba_layer = hparams.is_recurrent(il);
+
+            if (is_mamba_layer) {
+                // PLaMo-2 Mamba layer
+                cur = build_plamo2_mamba_layer(inp_hybrid->get_recr(), cur, model, ubatch, il);
+            } else {
+                // PLaMo-2 Attention layer
+                cur = build_plamo2_attn_layer(inp_hybrid->get_attn(), inp_pos, cur, model, il);
+            }
+
+            // post_mixer_norm
+            cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            // residual connection
+            cur = ggml_add(ctx0, cur, residual);
+            cb(cur, "attn_residual", il);
+            residual = cur;
+
+            // pre-ffn norm
+            cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_pre_norm", il);
+
+            // feed-forward network
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    NULL,                      NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
+            cb(cur, "ffn_out", il);
+
+            // post ffn norm
+            cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_post_norm", il);
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                residual = ggml_get_rows(ctx0, residual, inp_out_ids);
+            }
+
+            // residual connection
+            cur = ggml_add(ctx0, cur, residual);
+            cb(cur, "ffn_residual", il);
+
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        // final norm
+        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output", -1);
+
+        // Explicitly mark as output tensor to ensure proper backend assignment
+        ggml_set_output(cur);
+
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+
+private:
+    ggml_tensor * build_plamo2_attn_layer(
+            llm_graph_input_attn_kv_unified * inp,
+            ggml_tensor * inp_pos,
+            ggml_tensor * cur,
+            const llama_model & model,
+            int il) {
+
+        // self-attention
+        {
+            // PLaMo-2 uses combined QKV tensor
+            ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
+            cb(qkv, "wqkv", il);
+
+            // split QKV tensor into Q, K, V
+            const int64_t n_embd_head_q = hparams.n_embd_head_k;
+            const int64_t n_embd_head_k = hparams.n_embd_head_k;
+            const int64_t n_embd_head_v = hparams.n_embd_head_v;
+            int32_t n_head_kv = hparams.n_head_kv(il);
+
+            const int64_t q_offset = 0;
+            const int64_t k_offset = n_embd_head_q * n_head;
+            const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv;
+
+            ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv));
+            ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv));
+            ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv)));
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens);
+
+            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+            cb(Qcur, "Qcur_normed", il);
+
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+            cb(Kcur, "Kcur_normed", il);
+
+            Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il);
+        }
+
+        cb(cur, "attn_out", il);
+
+        return cur;
+    }
+
+    ggml_tensor * build_plamo2_mamba_layer(
+         llm_graph_input_rs * inp,
+               ggml_tensor * cur,
+         const llama_model & model,
+        const llama_ubatch & ubatch,
+                       int   il) {
+
+        const auto * mctx_cur = inp->mctx;
+
+        const auto kv_head = mctx_cur->get_head();
+
+        const int64_t d_conv   = hparams.ssm_d_conv;
+        const int64_t d_inner  = hparams.ssm_d_inner;
+        const int64_t d_state  = hparams.ssm_d_state;
+        const int64_t n_heads  = hparams.ssm_dt_rank;
+        const int64_t head_dim = d_inner / n_heads;
+        const int64_t n_group  = hparams.ssm_n_group;
+        const int64_t n_seqs   = ubatch.n_seqs;
+
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs());
+        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+        ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+        ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
+
+        ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+        conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
+
+        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+        // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
+        ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur);
+        cb(zx, "mamba_in_proj", il);
+        // {8192, 5, 1, 1} -> {8192, 1, 5, 1}
+        zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
+        zx = ggml_cont(ctx0, zx);
+        zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs);
+        cb(zx, "mamba_in_proj_out", il);
+
+        // split into z and x
+        // => {head_dim * n_heads, n_seq_tokens, n_seqs}
+        ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx));
+        x = ggml_cont(ctx0, x);
+        x = ggml_reshape_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs);
+        // x = ggml_permute(ctx0, x, 0, 2, 1, 3);
+        cb(x, "mamba_x_split", il);
+
+        ggml_tensor * z = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0);
+        cb(z, "mamba_z_split", il);
+
+        // conv1d
+        {
+            // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
+            ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
+            cb(conv_x, "mamba_conv1d_input", il);
+
+            // copy last (d_conv - 1) columns back into the state cache
+            ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs,
+                    conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
+
+            ggml_build_forward_expand(gf,
+                ggml_cpy(ctx0, last_conv,
+                    ggml_view_1d(ctx0, conv_states_all,
+                        (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
+                        kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
+            cb(conv_states_all, "mamba_conv1d_state", il);
+
+            // 1D convolution
+            x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
+            cb(x, "mamba_conv1d", il);
+
+            x = ggml_silu(ctx0, x);
+            cb(x, "mamba_conv1d_silu", il);
+        }
+
+        // SSM
+        {
+            // bcdt_proj: {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
+            ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x);
+            cb(x_bcdt, "mamba_bcdt_proj", il);
+
+            // split into dt, B, C
+            const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
+            ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
+            ggml_tensor * C  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*d_state);
+            ggml_tensor * dt  = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(2*d_state));
+            cb(B, "mamba_B_raw", il);
+            cb(C, "mamba_C_raw", il);
+            cb(dt, "mamba_dt_raw", il);
+
+            // Apply RMS norm to dt, B, C (PLaMo-2 specific)
+            B = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il);
+            C = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il);
+            dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il);
+            cb(B, "mamba_B_normed", il);
+            cb(C, "mamba_C_normed", il);
+            cb(dt, "mamba_dt_normed", il);
+
+            // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
+            dt = build_lora_mm(model.layers[il].ssm_dt, dt);
+            dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
+            cb(dt, "mamba_dt_proj", il);
+
+            ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads);
+            cb(A, "mamba_A", il);
+
+            x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
+            B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0);
+            C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0);
+
+            // use the states and the indices provided by build_recurrent_state
+            // (this is necessary in order to properly use the states before they are overwritten,
+            //  while avoiding to make unnecessary copies of the states)
+            auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
+                ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size());
+
+                // Custom operator to optimize the parallel associative scan
+                // as described in the Annex D of the Mamba paper.
+                // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+                return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+            };
+
+            ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
+            cb(y_ssm, "mamba_ssm_scan", il);
+
+            // store last states
+            ggml_build_forward_expand(gf,
+                ggml_cpy(ctx0,
+                    ggml_view_1d(ctx0, y_ssm, n_heads*head_dim*d_state*n_seqs, n_heads*head_dim*n_seq_tokens*n_seqs*ggml_element_size(y_ssm)),
+                    ggml_view_1d(ctx0, ssm_states_all, n_heads*head_dim*d_state*n_seqs, kv_head*n_seqs*n_heads*head_dim*d_state*ggml_element_size(ssm_states_all))));
+            cb(ssm_states_all, "mamba_ssm_states", il);
+
+            ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
+            cb(y, "mamba_y_view", il);
+
+            // Add D parameter and apply gating with z
+            // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
+            ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads);
+            y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D));
+            cb(y, "mamba_y_add_d", il);
+
+            y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
+            cb(y, "mamba_y_swiglu_z", il);
+
+            // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
+            y = ggml_view_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs, y->nb[2], y->nb[3], 0);
+            cur = build_lora_mm(model.layers[il].ssm_out, y);
+            cb(cur, "mamba_out_proj", il);
+        }
+
+        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
+        cb(cur, "mamba_out", il);
+
+        return cur;
+    }
+};
+
 struct llm_build_arcee : public llm_graph_context {
-    llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -15553,7 +16478,7 @@ struct llm_build_arcee : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
@@ -15612,7 +16537,7 @@ struct llm_build_arcee : public llm_graph_context {
 };
 
 struct llm_build_hunyuan_moe : public llm_graph_context {
-    llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -15698,7 +16623,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
                         LLM_NORM_RMS, il);
                 cb(Qcur, "Qcur_norm", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
@@ -15773,7 +16698,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
 };
 
 struct llm_build_smollm3 : public llm_graph_context {
-    llm_build_smollm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_smollm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -15850,7 +16775,7 @@ struct llm_build_smollm3 : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
@@ -15912,7 +16837,7 @@ struct llm_build_smollm3 : public llm_graph_context {
 struct llm_build_lfm2 : public llm_graph_context {
     const llama_model & model;
 
-    llm_build_lfm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
+    llm_build_lfm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {
 
         ggml_tensor * cur = build_inp_embd(model.tok_embd);
         cb(cur, "model.embed_tokens", -1);
@@ -15927,8 +16852,8 @@ struct llm_build_lfm2 : public llm_graph_context {
             cb(cur, "model.layers.{}.operator_norm", il);
 
             cur = hparams.is_recurrent(il) ?
-                build_shortconv_block(gf, cur, inp_hybrid->get_recr(), il) :
-                build_attn_block(gf, cur, inp_pos, inp_hybrid->get_attn(), il) ;
+                build_shortconv_block(cur, inp_hybrid->get_recr(), il) :
+                build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il) ;
 
             if (il == n_layer - 1 && inp_out_ids) {
                 cur      = ggml_get_rows(ctx0,      cur, inp_out_ids);
@@ -15971,8 +16896,7 @@ struct llm_build_lfm2 : public llm_graph_context {
         return cur;
     }
 
-    ggml_tensor * build_attn_block(ggml_cgraph                     * gf,
-                                   ggml_tensor                     * cur,
+    ggml_tensor * build_attn_block(ggml_tensor                     * cur,
                                    ggml_tensor                     * inp_pos,
                                    llm_graph_input_attn_kv_unified * inp_attn,
                                    int                               il) const {
@@ -16009,7 +16933,7 @@ struct llm_build_lfm2 : public llm_graph_context {
                 ext_factor, attn_factor, beta_fast, beta_slow
                 );
 
-        cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL,
+        cur = build_attn(inp_attn, model.layers[il].wo, NULL,
                 q, k, v, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
 
         cb(cur, "model.layers.{}.self_attn.out_proj", il);
@@ -16017,11 +16941,22 @@ struct llm_build_lfm2 : public llm_graph_context {
         return cur;
     }
 
-    ggml_tensor * build_shortconv_block(ggml_cgraph        * gf,
-                                        ggml_tensor        * cur,
+    ggml_tensor * build_shortconv_block(ggml_tensor        * cur,
                                         llm_graph_input_rs * inp_recr,
                                         int                il) {
-        const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+        const auto *   mctx_cur     = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+        const uint32_t kv_head      = mctx_cur->get_head();
+        const int64_t  n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t  n_seqs       = ubatch.n_seqs;
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs());
+        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+        GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
+        const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;
+
+        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
 
         auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
         cb(bcx, "model.layers.{}.conv.in_proj", il);
@@ -16029,38 +16964,48 @@ struct llm_build_lfm2 : public llm_graph_context {
         constexpr auto n_chunks = 3;
         GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
         auto const chunk_size = bcx->ne[0] / n_chunks;
-        auto * b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0 * chunk_size * ggml_element_size(bcx));
-        auto * c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1 * chunk_size * ggml_element_size(bcx));
-        auto * x = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2 * chunk_size * ggml_element_size(bcx));
+        auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 0*chunk_size*ggml_element_size(bcx));
+        auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 1*chunk_size*ggml_element_size(bcx));
+        auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 2*chunk_size*ggml_element_size(bcx));
 
         auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));
 
-        // read conv state directly, with build_rs generation is slower
-        ggml_tensor * conv_state = mctx_cur->get_r_l(il);
-        const int64_t n_seqs  = ubatch.n_seqs;
-        ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
-        conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs);
+        // read conv state
+        auto * conv_state = mctx_cur->get_r_l(il);
+        auto * conv_rs    = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
+        auto * conv       = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
 
         bx = ggml_concat(ctx0, conv, bx, 0);
         GGML_ASSERT(bx->ne[0] > conv->ne[0]);
 
-        auto * new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
+        // last d_conv columns is a new conv state
+        auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], (bx->ne[0] - conv->ne[0])*ggml_element_size(bx));
         GGML_ASSERT(ggml_are_same_shape(conv, new_conv));
 
-        // write conv state
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state));
+        // write new conv conv state
+        ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    new_conv,
+                    ggml_view_1d(
+                        ctx0,
+                        conv_state,
+                        ggml_nelements(new_conv),
+                        kv_head*d_conv*n_embd*ggml_element_size(new_conv)
+                        )
+                    )
+                );
 
         auto * conv_kernel = model.layers[il].shortconv.conv;
-        GGML_ASSERT(hparams.n_shortconv_l_cache > 0);
-
-        // construct ssm_conv op
-        ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
+        auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
         cb(conv_out, "model.layers.{}.conv.conv", il);
 
         auto * y = ggml_mul(ctx0, c, conv_out);
-
         y = build_lora_mm(model.layers[il].shortconv.out_proj, y);
         cb(y, "model.layers.{}.conv.out_proj", il);
+        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+        y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);
 
         return y;
     }
@@ -16078,6 +17023,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
+        case LLM_ARCH_DREAM:
             {
                 res = nullptr;
             } break;
@@ -16118,7 +17064,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 } else {
                     const auto padding = llama_kv_cache_unified::get_padding(cparams);
 
-                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    uint32_t n_ctx_per_stream = cparams.n_ctx;
+
+                    if (!cparams.kv_unified) {
+                        n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
+                        n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
+
+                        cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
+                    } else {
+                        n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
+
+                        cparams.n_ctx = n_ctx_per_stream;
+                    }
 
                     LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
@@ -16132,7 +17089,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 !cparams.flash_attn,
                                 cparams.offload_kqv,
                                 params.swa_full,
-                                cparams.n_ctx,
+                                cparams.kv_unified,
+                                n_ctx_per_stream,
                                 cparams.n_seq_max,
                                 cparams.n_ubatch,
                                 padding);
@@ -16146,7 +17104,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 params.type_v,
                                 !cparams.flash_attn,
                                 cparams.offload_kqv,
-                                cparams.n_ctx,
+                                cparams.kv_unified,
+                                n_ctx_per_stream,
                                 cparams.n_seq_max,
                                 padding,
                                 hparams.n_swa,
@@ -16159,227 +17118,233 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     return res;
 }
 
-llm_graph_result_ptr llama_model::build_graph(
-        const llm_graph_params & params,
-                   ggml_cgraph * gf,
-                llm_graph_type   type) const {
+ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
     std::unique_ptr<llm_graph_context> llm;
 
     switch (arch) {
         case LLM_ARCH_LLAMA:
             {
-                llm = std::make_unique<llm_build_llama>(*this, params, gf);
+                llm = std::make_unique<llm_build_llama>(*this, params);
             } break;
         case LLM_ARCH_LLAMA4:
             {
-                llm = std::make_unique<llm_build_llama_iswa>(*this, params, gf);
+                llm = std::make_unique<llm_build_llama_iswa>(*this, params);
             } break;
         case LLM_ARCH_DECI:
             {
-                llm = std::make_unique<llm_build_deci>(*this, params, gf);
+                llm = std::make_unique<llm_build_deci>(*this, params);
             } break;
         case LLM_ARCH_BAICHUAN:
             {
-                llm = std::make_unique<llm_build_baichuan>(*this, params, gf);
+                llm = std::make_unique<llm_build_baichuan>(*this, params);
             } break;
         case LLM_ARCH_FALCON:
             {
-                llm = std::make_unique<llm_build_falcon>(*this, params, gf);
+                llm = std::make_unique<llm_build_falcon>(*this, params);
             } break;
         case LLM_ARCH_GROK:
             {
-                llm = std::make_unique<llm_build_grok>(*this, params, gf);
+                llm = std::make_unique<llm_build_grok>(*this, params);
             } break;
         case LLM_ARCH_STARCODER:
             {
-                llm = std::make_unique<llm_build_starcoder>(*this, params, gf);
+                llm = std::make_unique<llm_build_starcoder>(*this, params);
             } break;
         case LLM_ARCH_REFACT:
             {
-                llm = std::make_unique<llm_build_refact>(*this, params, gf);
+                llm = std::make_unique<llm_build_refact>(*this, params);
             } break;
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
             {
-                llm = std::make_unique<llm_build_bert>(*this, params, gf);
+                llm = std::make_unique<llm_build_bert>(*this, params);
             } break;
         case LLM_ARCH_NEO_BERT:
             {
-                llm = std::make_unique<llm_build_neo_bert>(*this, params, gf);
+                llm = std::make_unique<llm_build_neo_bert>(*this, params);
             } break;
         case LLM_ARCH_BLOOM:
             {
-                llm = std::make_unique<llm_build_bloom>(*this, params, gf);
+                llm = std::make_unique<llm_build_bloom>(*this, params);
             } break;
         case LLM_ARCH_MPT:
             {
-                llm = std::make_unique<llm_build_mpt>(*this, params, gf);
+                llm = std::make_unique<llm_build_mpt>(*this, params);
             } break;
         case LLM_ARCH_STABLELM:
             {
-                llm = std::make_unique<llm_build_stablelm>(*this, params, gf);
+                llm = std::make_unique<llm_build_stablelm>(*this, params);
             } break;
         case LLM_ARCH_QWEN:
             {
-                llm = std::make_unique<llm_build_qwen>(*this, params, gf);
+                llm = std::make_unique<llm_build_qwen>(*this, params);
             } break;
         case LLM_ARCH_QWEN2:
             {
-                llm = std::make_unique<llm_build_qwen2>(*this, params, gf);
+                llm = std::make_unique<llm_build_qwen2>(*this, params);
             } break;
+        case LLM_ARCH_DREAM:
+            {
+                llm = std::make_unique<llm_build_dream>(*this, params);
+            }
+            break;
         case LLM_ARCH_QWEN2VL:
             {
-                llm = std::make_unique<llm_build_qwen2vl>(*this, params, gf);
+                llm = std::make_unique<llm_build_qwen2vl>(*this, params);
             } break;
         case LLM_ARCH_QWEN2MOE:
             {
-                llm = std::make_unique<llm_build_qwen2moe>(*this, params, gf);
+                llm = std::make_unique<llm_build_qwen2moe>(*this, params);
             } break;
         case LLM_ARCH_QWEN3:
             {
-                llm = std::make_unique<llm_build_qwen3>(*this, params, gf);
+                llm = std::make_unique<llm_build_qwen3>(*this, params);
             } break;
         case LLM_ARCH_QWEN3MOE:
             {
-                llm = std::make_unique<llm_build_qwen3moe>(*this, params, gf);
+                llm = std::make_unique<llm_build_qwen3moe>(*this, params);
             } break;
         case LLM_ARCH_PHI2:
             {
-                llm = std::make_unique<llm_build_phi2>(*this, params, gf);
+                llm = std::make_unique<llm_build_phi2>(*this, params);
             } break;
         case LLM_ARCH_PHI3:
         case LLM_ARCH_PHIMOE:
             {
                 if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    llm = std::make_unique<llm_build_phi3<true>> (*this, params, gf);
+                    llm = std::make_unique<llm_build_phi3<true>> (*this, params);
                 } else {
-                    llm = std::make_unique<llm_build_phi3<false>>(*this, params, gf);
+                    llm = std::make_unique<llm_build_phi3<false>>(*this, params);
                 }
             } break;
         case LLM_ARCH_PLAMO:
             {
-                llm = std::make_unique<llm_build_plamo>(*this, params, gf);
+                llm = std::make_unique<llm_build_plamo>(*this, params);
+            } break;
+        case LLM_ARCH_PLAMO2:
+            {
+                llm = std::make_unique<llm_build_plamo2>(*this, params);
             } break;
         case LLM_ARCH_GPT2:
             {
-                llm = std::make_unique<llm_build_gpt2>(*this, params, gf);
+                llm = std::make_unique<llm_build_gpt2>(*this, params);
             } break;
         case LLM_ARCH_CODESHELL:
             {
-                llm = std::make_unique<llm_build_codeshell>(*this, params, gf);
+                llm = std::make_unique<llm_build_codeshell>(*this, params);
             } break;
         case LLM_ARCH_ORION:
             {
-                llm = std::make_unique<llm_build_orion>(*this, params, gf);
+                llm = std::make_unique<llm_build_orion>(*this, params);
             } break;
         case LLM_ARCH_INTERNLM2:
             {
-                llm = std::make_unique<llm_build_internlm2>(*this, params, gf);
+                llm = std::make_unique<llm_build_internlm2>(*this, params);
             } break;
         case LLM_ARCH_MINICPM3:
             {
-                llm = std::make_unique<llm_build_minicpm3>(*this, params, gf);
+                llm = std::make_unique<llm_build_minicpm3>(*this, params);
             } break;
         case LLM_ARCH_GEMMA:
             {
-                llm = std::make_unique<llm_build_gemma>(*this, params, gf);
+                llm = std::make_unique<llm_build_gemma>(*this, params);
             } break;
         case LLM_ARCH_GEMMA2:
             {
-                llm = std::make_unique<llm_build_gemma2_iswa>(*this, params, gf);
+                llm = std::make_unique<llm_build_gemma2_iswa>(*this, params);
             } break;
         case LLM_ARCH_GEMMA3:
             {
-                llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
+                llm = std::make_unique<llm_build_gemma3_iswa>(*this, params);
             } break;
         case LLM_ARCH_GEMMA3N:
             {
-                llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params, gf);
+                llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params);
             } break;
         case LLM_ARCH_STARCODER2:
             {
-                llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
+                llm = std::make_unique<llm_build_starcoder2>(*this, params);
             } break;
         case LLM_ARCH_MAMBA:
         case LLM_ARCH_MAMBA2:
             {
-                llm = std::make_unique<llm_build_mamba>(*this, params, gf);
+                llm = std::make_unique<llm_build_mamba>(*this, params);
             } break;
         case LLM_ARCH_JAMBA:
             {
-                llm = std::make_unique<llm_build_jamba>(*this, params, gf);
+                llm = std::make_unique<llm_build_jamba>(*this, params);
             } break;
         case LLM_ARCH_XVERSE:
             {
-                llm = std::make_unique<llm_build_xverse>(*this, params, gf);
+                llm = std::make_unique<llm_build_xverse>(*this, params);
             } break;
         case LLM_ARCH_COMMAND_R:
             {
-                llm = std::make_unique<llm_build_command_r>(*this, params, gf);
+                llm = std::make_unique<llm_build_command_r>(*this, params);
             } break;
         case LLM_ARCH_COHERE2:
             {
-                llm = std::make_unique<llm_build_cohere2_iswa>(*this, params, gf);
+                llm = std::make_unique<llm_build_cohere2_iswa>(*this, params);
             } break;
         case LLM_ARCH_DBRX:
             {
-                llm = std::make_unique<llm_build_dbrx>(*this, params, gf);
+                llm = std::make_unique<llm_build_dbrx>(*this, params);
             } break;
         case LLM_ARCH_OLMO:
             {
-                llm = std::make_unique<llm_build_olmo>(*this, params, gf);
+                llm = std::make_unique<llm_build_olmo>(*this, params);
             } break;
         case LLM_ARCH_OLMO2:
             {
-                llm = std::make_unique<llm_build_olmo2>(*this, params, gf);
+                llm = std::make_unique<llm_build_olmo2>(*this, params);
             } break;
         case LLM_ARCH_OLMOE:
             {
-                llm = std::make_unique<llm_build_olmoe>(*this, params, gf);
+                llm = std::make_unique<llm_build_olmoe>(*this, params);
             } break;
         case LLM_ARCH_OPENELM:
             {
-                llm = std::make_unique<llm_build_openelm>(*this, params, gf);
+                llm = std::make_unique<llm_build_openelm>(*this, params);
             } break;
         case LLM_ARCH_GPTNEOX:
             {
-                llm = std::make_unique<llm_build_gptneox>(*this, params, gf);
+                llm = std::make_unique<llm_build_gptneox>(*this, params);
             } break;
         case LLM_ARCH_ARCTIC:
             {
-                llm = std::make_unique<llm_build_arctic>(*this, params, gf);
+                llm = std::make_unique<llm_build_arctic>(*this, params);
             } break;
         case LLM_ARCH_DEEPSEEK:
             {
-                llm = std::make_unique<llm_build_deepseek>(*this, params, gf);
+                llm = std::make_unique<llm_build_deepseek>(*this, params);
             } break;
         case LLM_ARCH_DEEPSEEK2:
             {
-                llm = std::make_unique<llm_build_deepseek2>(*this, params, gf);
+                llm = std::make_unique<llm_build_deepseek2>(*this, params);
             } break;
         case LLM_ARCH_CHATGLM:
             {
-                llm = std::make_unique<llm_build_chatglm>(*this, params, gf);
+                llm = std::make_unique<llm_build_chatglm>(*this, params);
             } break;
         case LLM_ARCH_GLM4:
             {
-                llm = std::make_unique<llm_build_glm4>(*this, params, gf);
+                llm = std::make_unique<llm_build_glm4>(*this, params);
             } break;
         case LLM_ARCH_BITNET:
             {
-                llm = std::make_unique<llm_build_bitnet>(*this, params, gf);
+                llm = std::make_unique<llm_build_bitnet>(*this, params);
             } break;
         case LLM_ARCH_T5:
             {
-                switch (type) {
+                switch (params.gtype) {
                     case LLM_GRAPH_TYPE_ENCODER:
-                        llm = std::make_unique<llm_build_t5_enc>(*this, params, gf);
+                        llm = std::make_unique<llm_build_t5_enc>(*this, params);
                         break;
                     case LLM_GRAPH_TYPE_DEFAULT:
                     case LLM_GRAPH_TYPE_DECODER:
-                        llm = std::make_unique<llm_build_t5_dec>(*this, params, gf);
+                        llm = std::make_unique<llm_build_t5_dec>(*this, params);
                         break;
                     default:
                         GGML_ABORT("invalid graph type");
@@ -16387,99 +17352,111 @@ llm_graph_result_ptr llama_model::build_graph(
             } break;
         case LLM_ARCH_T5ENCODER:
             {
-                llm = std::make_unique<llm_build_t5_enc>(*this, params, gf);
+                llm = std::make_unique<llm_build_t5_enc>(*this, params);
             }
             break;
         case LLM_ARCH_JAIS:
             {
-                llm = std::make_unique<llm_build_jais>(*this, params, gf);
+                llm = std::make_unique<llm_build_jais>(*this, params);
             } break;
         case LLM_ARCH_NEMOTRON:
             {
-                llm = std::make_unique<llm_build_nemotron>(*this, params, gf);
+                llm = std::make_unique<llm_build_nemotron>(*this, params);
             } break;
         case LLM_ARCH_EXAONE:
             {
-                llm = std::make_unique<llm_build_exaone>(*this, params, gf);
+                llm = std::make_unique<llm_build_exaone>(*this, params);
+            } break;
+        case LLM_ARCH_EXAONE4:
+            {
+                if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
+                    llm = std::make_unique<llm_build_exaone4<true>>(*this, params);
+                } else {
+                    llm = std::make_unique<llm_build_exaone4<false>>(*this, params);
+                }
             } break;
         case LLM_ARCH_RWKV6:
             {
-                llm = std::make_unique<llm_build_rwkv6>(*this, params, gf);
+                llm = std::make_unique<llm_build_rwkv6>(*this, params);
             } break;
         case LLM_ARCH_RWKV6QWEN2:
             {
-                llm = std::make_unique<llm_build_rwkv6qwen2>(*this, params, gf);
+                llm = std::make_unique<llm_build_rwkv6qwen2>(*this, params);
             } break;
         case LLM_ARCH_RWKV7:
             {
-                llm = std::make_unique<llm_build_rwkv7>(*this, params, gf);
+                llm = std::make_unique<llm_build_rwkv7>(*this, params);
             } break;
         case LLM_ARCH_ARWKV7:
             {
-                llm = std::make_unique<llm_build_arwkv7>(*this, params, gf);
+                llm = std::make_unique<llm_build_arwkv7>(*this, params);
             } break;
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
         case LLM_ARCH_MINICPM:
             {
-                llm = std::make_unique<llm_build_granite>(*this, params, gf);
+                llm = std::make_unique<llm_build_granite>(*this, params);
             } break;
         case LLM_ARCH_GRANITE_HYBRID:
             {
-                llm = std::make_unique<llm_build_granite_hybrid>(*this, params, gf);
+                llm = std::make_unique<llm_build_granite_hybrid>(*this, params);
             } break;
         case LLM_ARCH_CHAMELEON:
             {
-                llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
+                llm = std::make_unique<llm_build_chameleon>(*this, params);
             } break;
         case LLM_ARCH_WAVTOKENIZER_DEC:
             {
-                llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params, gf);
+                llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params);
             } break;
         case LLM_ARCH_PLM:
             {
-                llm = std::make_unique<llm_build_plm>(*this, params, gf);
+                llm = std::make_unique<llm_build_plm>(*this, params);
             } break;
         case LLM_ARCH_BAILINGMOE:
             {
-                llm = std::make_unique<llm_build_bailingmoe>(*this, params, gf);
+                llm = std::make_unique<llm_build_bailingmoe>(*this, params);
             } break;
         case LLM_ARCH_DOTS1:
             {
-                llm = std::make_unique<llm_build_dots1>(*this, params, gf);
+                llm = std::make_unique<llm_build_dots1>(*this, params);
             } break;
         case LLM_ARCH_ARCEE:
             {
-                llm = std::make_unique<llm_build_arcee>(*this, params, gf);
+                llm = std::make_unique<llm_build_arcee>(*this, params);
             } break;
         case LLM_ARCH_ERNIE4_5:
             {
-                llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
+                llm = std::make_unique<llm_build_ernie4_5>(*this, params);
+            } break;
+        case LLM_ARCH_ERNIE4_5_MOE:
+            {
+                llm = std::make_unique<llm_build_ernie4_5_moe>(*this, params);
             } break;
         case LLM_ARCH_HUNYUAN_MOE:
             {
-                llm = std::make_unique<llm_build_hunyuan_moe>(*this, params, gf);
+                llm = std::make_unique<llm_build_hunyuan_moe>(*this, params);
             } break;
         case LLM_ARCH_SMOLLM3:
             {
-                llm = std::make_unique<llm_build_smollm3>(*this, params, gf);
+                llm = std::make_unique<llm_build_smollm3>(*this, params);
             } break;
         case LLM_ARCH_FALCON_H1:
             {
-                llm = std::make_unique<llm_build_falcon_h1>(*this, params, gf);
+                llm = std::make_unique<llm_build_falcon_h1>(*this, params);
             } break;
         case LLM_ARCH_LFM2:
             {
-                llm = std::make_unique<llm_build_lfm2>(*this, params, gf);
+                llm = std::make_unique<llm_build_lfm2>(*this, params);
             } break;
         default:
             GGML_ABORT("fatal error");
     }
 
     // add on pooling layer
-    llm->build_pooling(gf, cls, cls_b, cls_out, cls_out_b);
+    llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
 
-    return std::move(llm->res);
+    return llm->res->get_gf();
 }
 
 //
@@ -16628,6 +17605,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_SMOLLM3:
         case LLM_ARCH_ARCEE:
         case LLM_ARCH_ERNIE4_5:
+        case LLM_ARCH_ERNIE4_5_MOE:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2
@@ -16642,6 +17620,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_BITNET:
         case LLM_ARCH_QWEN:
         case LLM_ARCH_QWEN2:
+        case LLM_ARCH_DREAM:
         case LLM_ARCH_QWEN2MOE:
         case LLM_ARCH_QWEN3:
         case LLM_ARCH_QWEN3MOE:
@@ -16651,6 +17630,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_PHI3:
         case LLM_ARCH_PHIMOE:
         case LLM_ARCH_PLAMO:
+        case LLM_ARCH_PLAMO2:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_GEMMA3:
@@ -16662,6 +17642,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_ORION:
         case LLM_ARCH_NEMOTRON:
         case LLM_ARCH_EXAONE:
+        case LLM_ARCH_EXAONE4:
         case LLM_ARCH_MINICPM3:
         case LLM_ARCH_DOTS1:
         case LLM_ARCH_HUNYUAN_MOE:
index 027a7f0c3e2c694ae55171efd1524df99488a118..094e23808a81392674c1122369cdef9c3651dc57 100644 (file)
@@ -99,8 +99,10 @@ enum llm_type {
     LLM_TYPE_17B_16E, // llama4 Scout
     LLM_TYPE_17B_128E, // llama4 Maverick
     LLM_TYPE_A13B,
+    LLM_TYPE_21B_A3B, // Ernie MoE small
     LLM_TYPE_30B_A3B,
     LLM_TYPE_235B_A22B,
+    LLM_TYPE_300B_A47B, // Ernie MoE big
     LLM_TYPE_E2B,
     LLM_TYPE_E4B,
 };
@@ -452,10 +454,7 @@ struct llama_model {
     llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
 
     // TODO: move this to new llm_arch_model_i interface
-    llm_graph_result_ptr build_graph(
-            const llm_graph_params & params,
-                       ggml_cgraph * gf,
-                    llm_graph_type   type) const;
+    ggml_cgraph * build_graph(const llm_graph_params & params) const;
 
 private:
     struct impl;
index 4dbd1e309919a52c98b25fa93ac1a514c8641d03..a00af7a1d1758855ec5f8febba2c4a0015cb7710 100644 (file)
@@ -884,8 +884,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
                         if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
                             if  (qtype != new_type) {
                                 LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
-                                new_type = qtype;
-                                break; // if two or more types are specified for the tensor, first match wins
+                                new_type = qtype; // if two or more types are specified for the same tensor, the last match wins
                             }
                         }
                     }
index e0e578d6394d822a4c927b241ff71d16c7140aa2..e8bae645088dded8d15be3b0b58ecafec53c8c19 100644 (file)
@@ -11,6 +11,7 @@
 #include <cassert>
 #include <cctype>
 #include <cfloat>
+#include <cmath>
 #include <cstdarg>
 #include <cstring>
 #include <forward_list>
@@ -404,6 +405,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_KIMI_K2:
+                regex_exprs = {
+                    // K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp
+                    // The custom handler implements all K2 patterns with proper Han character exclusion
+                    "\\p{Han}+",
+                };
+                break;
             case LLAMA_VOCAB_PRE_TYPE_SUPERBPE:
                 regex_exprs = {
                     "\\p{N}+",
@@ -1196,6 +1204,284 @@ private:
     const llm_tokenizer_rwkv & tokenizer;
 };
 
+struct llm_tokenizer_plamo2 : llm_tokenizer {
+    llm_tokenizer_plamo2(const llama_vocab & vocab) {
+        build(vocab);
+    }
+
+    void build(const llama_vocab & vocab) {
+        // Reset internal structures
+        tokens_.clear();
+        bytes_.assign(256, 0);
+        to_suffix_id_.clear();
+        table_.clear();
+
+        // Build token list and byte mapping
+        std::unordered_map<std::string, float> suffix_to_score;
+        std::unordered_map<std::string, llama_token> token_to_id;
+
+        for (size_t token_id = 0; token_id < vocab.n_tokens(); ++token_id) {
+            const auto & entry = vocab.get_token_data(token_id);
+            tokens_.push_back(entry.text);
+            token_to_id[entry.text] = static_cast<llama_token>(token_id);
+
+            // Handle byte tokens
+            if (vocab.is_byte(token_id)) {
+                if (entry.text.length() == 6 && entry.text.substr(0, 3) == "<0x" && entry.text.back() == '>') {
+                    std::string hex_str = entry.text.substr(3, 2);
+                    int byte_val = std::stoi(hex_str, nullptr, 16);
+                    bytes_[byte_val] = static_cast<llama_token>(token_id);
+                }
+                continue;
+            }
+
+            // Add token and all its suffixes to suffix_to_score
+            suffix_to_score[entry.text] = entry.score;
+
+            // Extract suffixes character by character (UTF-8 aware)
+            std::vector<uint32_t> cpts = unicode_cpts_from_utf8(entry.text);
+            for (size_t i = 1; i < cpts.size(); ++i) {
+                std::string suffix;
+                for (size_t j = i; j < cpts.size(); ++j) {
+                    suffix += unicode_cpt_to_utf8(cpts[j]);
+                }
+                if (suffix_to_score.find(suffix) == suffix_to_score.end()) {
+                    suffix_to_score[suffix] = std::numeric_limits<float>::quiet_NaN();
+                }
+            }
+        }
+
+        // Check that all byte tokens are set
+        for (int i = 0; i < 256; ++i) {
+            if (bytes_[i] == 0) {
+                throw std::runtime_error("Byte token for <0x" + std::to_string(i) + "> is not set");
+            }
+        }
+
+        // Build suffix list in lexicographical order of reversed strings
+        std::vector<std::string> suffixes;
+        for (const auto & pair : suffix_to_score) {
+            suffixes.push_back(pair.first);
+        }
+        suffixes.push_back("");  // Empty suffix
+
+        std::sort(suffixes.begin(), suffixes.end(), [](const std::string & a, const std::string & b) {
+            std::string rev_a(a.rbegin(), a.rend());
+            std::string rev_b(b.rbegin(), b.rend());
+            return rev_a < rev_b;
+        });
+
+        // Build suffix_to_id and to_suffix_id_
+        std::unordered_map<std::string, int32_t> suffix_to_id;
+        int32_t num_pieces = 0;
+
+        for (const auto & suffix : suffixes) {
+            suffix_to_id[suffix] = num_pieces;
+            if (!suffix.empty()) {
+                std::vector<uint32_t> cpts = unicode_cpts_from_utf8(suffix);
+
+                std::string remaining;
+                for (size_t i = 1; i < cpts.size(); ++i) {
+                    remaining += unicode_cpt_to_utf8(cpts[i]);
+                }
+
+                int64_t piece_code = (static_cast<int64_t>(cpts[0]) << 32) | suffix_to_id[remaining];
+                to_suffix_id_[piece_code] = num_pieces;
+
+                // Count number of pieces for this suffix
+                int32_t pieces_for_suffix = 1; // sentinel row
+                for (int32_t piece_length = static_cast<int32_t>(cpts.size()); piece_length > 0; --piece_length) {
+                    std::string piece;
+                    for (int32_t i = 0; i < piece_length; ++i) {
+                        piece += unicode_cpt_to_utf8(cpts[i]);
+                    }
+                    if (suffix_to_score.find(piece) != suffix_to_score.end()) {
+                        pieces_for_suffix++;
+                    }
+                }
+                num_pieces += pieces_for_suffix;
+            } else {
+                num_pieces++;  // Empty suffix contributes one piece (sentinel row)
+            }
+        }
+
+        // Build flattened table
+        table_.resize(num_pieces, std::vector<int32_t>(4, 0));
+        int32_t table_idx = 0;
+
+        for (const auto & suffix : suffixes) {
+            // Add all prefixes of the suffix to the table (in decreasing order of length)
+            std::vector<uint32_t> cpts = unicode_cpts_from_utf8(suffix);
+            for (int32_t piece_length = static_cast<int32_t>(cpts.size()); piece_length > 0; --piece_length) {
+                std::string piece;
+                for (int32_t i = 0; i < piece_length; ++i) {
+                    piece += unicode_cpt_to_utf8(cpts[i]);
+                }
+
+                auto score_it = suffix_to_score.find(piece);
+                if (score_it == suffix_to_score.end()) {
+                    continue;
+                }
+
+                table_[table_idx][TABLE_PIECE_LENGTH] = piece_length;
+                auto token_it = token_to_id.find(piece);
+                table_[table_idx][TABLE_TOKEN_ID] = (token_it != token_to_id.end()) ? token_it->second : -1;
+
+                float score = score_it->second;
+                table_[table_idx][TABLE_SCORE] = std::isfinite(score) ?
+                    static_cast<int32_t>(std::round(score * 1e4)) : INVALID_SCORE;
+                table_[table_idx][TABLE_PIECE_ID] = suffix_to_id[piece];
+
+                table_idx++;
+            }
+
+            // Add sentinel row
+            table_[table_idx][TABLE_PIECE_LENGTH] = 1;
+            table_[table_idx][TABLE_TOKEN_ID] = -1;
+            table_[table_idx][TABLE_SCORE] = UNKNOWN_SCORE;
+            table_idx++;
+        }
+    }
+
+    std::vector<llama_token> encode(const std::string & text) const {
+        std::vector<uint32_t> unicode_data = unicode_cpts_from_utf8(text);
+        // Skip the first code point if it is a BOM (Byte Order Mark)
+        if (!unicode_data.empty() && unicode_data[0] == 0xFEFF) {
+            unicode_data.erase(unicode_data.begin());
+        }
+
+        if (unicode_data.empty()) {
+            return {};
+        }
+
+        const size_t data_len = unicode_data.size();
+
+        // Initialize scores array (dynamic programming)
+        std::vector<int64_t> scores(data_len + 1, static_cast<int64_t>(1) << 60);
+        scores[data_len] = 0;
+
+        // Path array to track best tokenization
+        std::vector<std::vector<int32_t>> path(data_len + 1, std::vector<int32_t>(3, 0));
+
+        int32_t suffix_id = 0;
+
+        // Process from end to beginning
+        for (int i = static_cast<int>(data_len) - 1; i >= 0; --i) {
+            uint32_t c = unicode_data[i];
+
+            // Find next suffix ID
+            for (size_t p = suffix_id; p < table_.size(); ++p) {
+                int64_t piece_code = (static_cast<int64_t>(c) << 32) | table_[p][TABLE_PIECE_ID];
+                auto it = to_suffix_id_.find(piece_code);
+                suffix_id = (it != to_suffix_id_.end()) ? it->second : 0;
+
+                if (suffix_id > 0 || table_[p][TABLE_SCORE] == UNKNOWN_SCORE) {
+                    break;
+                }
+            }
+
+            // Update best path
+            for (size_t p = suffix_id; p < table_.size(); ++p) {
+                int32_t score = table_[p][TABLE_SCORE];
+                if (score > INVALID_SCORE) {
+                    int32_t piece_length = table_[p][TABLE_PIECE_LENGTH];
+                    int64_t s = scores[i + piece_length] - score;
+
+                    if (s < scores[i]) {
+                        scores[i] = s;
+                        path[i][PATH_TOKEN_LENGTH] = piece_length;
+                        path[i][PATH_TOKEN_ID] = table_[p][TABLE_TOKEN_ID];
+                        path[i][PATH_NUM_TOKENS] = path[i + piece_length][PATH_NUM_TOKENS] + 1;
+
+                        if (score == UNKNOWN_SCORE) {
+                            // Add UTF-8 byte count
+                            path[i][PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
+                        }
+                    }
+                }
+
+                if (score == UNKNOWN_SCORE) {
+                    break;
+                }
+            }
+        }
+
+        // Decode the best path
+        std::vector<llama_token> token_ids;
+        token_ids.reserve(path[0][PATH_NUM_TOKENS]);
+
+        int pos = 0;
+        while (pos < static_cast<int>(data_len)) {
+            if (path[pos][PATH_TOKEN_ID] >= 0) {
+                token_ids.push_back(path[pos][PATH_TOKEN_ID]);
+            } else {
+                // Fall back to byte tokens
+                uint32_t c = unicode_data[pos];
+                int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
+
+                for (int i = 0; i < s; ++i) {
+                    uint8_t b;
+                    if (s == 1) {
+                        b = c;
+                    } else {
+                        if (i == 0) {
+                            b = (0xF00 >> s) & 0xFF;
+                        } else {
+                            b = 0x80;
+                        }
+                    }
+                    token_ids.push_back(bytes_[b | ((c >> ((s - i - 1) * 6)) & 0x3F)]);
+                }
+            }
+
+            assert(path[pos][PATH_TOKEN_LENGTH] > 0);
+            pos += path[pos][PATH_TOKEN_LENGTH];
+        }
+
+        return token_ids;
+    }
+private:
+    // Constants for table structure
+    static constexpr int32_t TABLE_PIECE_LENGTH = 0;
+    static constexpr int32_t TABLE_TOKEN_ID     = 1;
+    static constexpr int32_t TABLE_SCORE        = 2;
+    static constexpr int32_t TABLE_PIECE_ID     = 3;
+
+    // Constants for path array
+    static constexpr int32_t PATH_TOKEN_LENGTH  = 0;
+    static constexpr int32_t PATH_TOKEN_ID      = 1;
+    static constexpr int32_t PATH_NUM_TOKENS    = 2;
+
+    // Score constants
+    static constexpr int32_t INVALID_SCORE = -20000000;
+    static constexpr int32_t UNKNOWN_SCORE = -10000000;
+
+    // List of tokens in the vocabulary
+    std::vector<std::string> tokens_;
+
+    // Mapping from byte code point to token ID (for byte fallback)
+    std::vector<llama_token> bytes_;
+
+    // Mapping from piece code to suffix ID
+    std::unordered_map<int64_t, int32_t> to_suffix_id_;
+
+    // Flattened table representing the Trie structure
+    // Each row contains: [piece_length, token_id, score, piece_id]
+    std::vector<std::vector<int32_t>> table_;
+};
+
+struct llm_tokenizer_plamo2_session {
+    llm_tokenizer_plamo2_session(const llm_tokenizer_plamo2 & tokenizer) : tokenizer(tokenizer) {}
+
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
+        std::vector<llama_token> tokens = tokenizer.encode(text);
+        output.insert(output.end(), tokens.begin(), tokens.end());
+    }
+
+private:
+    const llm_tokenizer_plamo2 & tokenizer;
+};
+
 //
 // impl
 //
@@ -1499,6 +1785,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             special_unk_id = LLAMA_TOKEN_NULL;
             special_sep_id = LLAMA_TOKEN_NULL;
             special_pad_id = LLAMA_TOKEN_NULL;
+        } else if (tokenizer_model == "plamo2") {
+            type = LLAMA_VOCAB_TYPE_PLAMO2;
+
+            // PLaMo-2 default special tokens (these will be overridden by model config)
+            special_bos_id = 1;  // <|plamo:bos|>
+            special_eos_id = 2;  // <|plamo:eos|>
+            special_unk_id = 0;  // <|plamo:unk|>
+            special_sep_id = LLAMA_TOKEN_NULL;
+            special_pad_id = 3;  // <|plamo:pad|>
+            special_mask_id = LLAMA_TOKEN_NULL;
         } else {
             throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
         }
@@ -1629,6 +1925,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             } else if (
                 tokenizer_pre == "exaone") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE;
+            } else if (
+                tokenizer_pre == "exaone4") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
             } else if (
                 tokenizer_pre == "chameleon") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON;
@@ -1665,6 +1964,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "hunyuan") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
                 clean_spaces = false;
+            } else if (
+                tokenizer_pre == "kimi-k2") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
+                clean_spaces = false;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
@@ -2145,13 +2448,14 @@ enum llama_vocab_type llama_vocab::impl::get_type() const {
 
 std::string llama_vocab::impl::type_name() const{
     switch (type) {
-        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
-        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
-        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
-        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
-        case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
-        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
-        default:                    return "unknown";
+        case LLAMA_VOCAB_TYPE_NONE:   return "no vocab";
+        case LLAMA_VOCAB_TYPE_SPM:    return "SPM";
+        case LLAMA_VOCAB_TYPE_BPE:    return "BPE";
+        case LLAMA_VOCAB_TYPE_WPM:    return "WPM";
+        case LLAMA_VOCAB_TYPE_UGM:    return "UGM";
+        case LLAMA_VOCAB_TYPE_RWKV:   return "RWKV";
+        case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2";
+        default:                      return "unknown";
     }
 }
 
@@ -2234,6 +2538,9 @@ void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) {
         case LLAMA_VOCAB_TYPE_RWKV:
             tokenizer = std::make_unique<llm_tokenizer_rwkv>(vocab);
             break;
+        case LLAMA_VOCAB_TYPE_PLAMO2:
+            tokenizer = std::make_unique<llm_tokenizer_plamo2>(vocab);
+            break;
         default:
             GGML_ABORT("unsupported vocab type");
     }
@@ -2566,6 +2873,23 @@ std::vector<llama_token> llama_vocab::impl::tokenize(
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
 
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_PLAMO2:
+            {
+                llm_tokenizer_plamo2_session session(*static_cast<const llm_tokenizer_plamo2 *>(tokenizer.get()));
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
 #endif
@@ -2664,6 +2988,24 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
                 memcpy(buf, result.data(), result.size());
                 return (int)result.size();
             }
+            case LLAMA_VOCAB_TYPE_PLAMO2: {
+                // PLaMo-2 uses similar token handling as BPE/SPM
+                if (vocab.is_byte(token)) {
+                    // Handle byte tokens like <0xXX>
+                    if (token_text.length() == 6 && token_text.substr(0, 3) == "<0x" && token_text.back() == '>') {
+                        int hex_val = std::stoi(token_text.substr(3, 2), nullptr, 16);
+                        if (length < 1) {
+                            return -1;
+                        }
+                        buf[0] = static_cast<char>(hex_val);
+                        return 1;
+                    }
+                }
+
+                // Normal token - just copy the text
+                std::string result = token_text;
+                return _try_copy(result.data(), result.size());
+            }
             default:
                 GGML_ABORT("fatal error");
         }
@@ -2908,6 +3250,12 @@ llama_token llama_vocab::byte_to_token(uint8_t ch) const {
         case LLAMA_VOCAB_TYPE_BPE: {
             return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
         }
+        case LLAMA_VOCAB_TYPE_PLAMO2: {
+            // PLaMo-2 uses byte tokens in format <0xXX>
+            char hex_str[8];
+            snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch);
+            return pimpl->token_to_id.at(hex_str);
+        }
         default:
             GGML_ABORT("fatal error");
     }
@@ -3009,6 +3357,10 @@ llama_token llama_vocab::token_fim_sep() const {
     return pimpl->special_fim_sep_id;
 }
 
+llama_token llama_vocab::token_mask() const {
+    return pimpl->special_mask_id;
+}
+
 bool llama_vocab::get_add_space_prefix() const {
     return pimpl->add_space_prefix;
 }
@@ -3249,6 +3601,10 @@ llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
     return vocab->token_fim_sep();
 }
 
+llama_token llama_vocab_mask(const struct llama_vocab* vocab) {
+    return vocab->token_mask();
+}
+
 // deprecated
 const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
     return llama_vocab_get_text(vocab, token);
@@ -3385,4 +3741,3 @@ int32_t llama_detokenize(
                         bool   unparse_special) {
     return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
 }
-
index 46a1ccecb51fcf61a8beeda7871579d940b3e154..842b129e86171dd9f1def51e464e1cac6ab06828 100644 (file)
@@ -45,6 +45,7 @@ enum llama_vocab_pre_type {
     LLAMA_VOCAB_PRE_TYPE_PIXTRAL        = 34,
     LLAMA_VOCAB_PRE_TYPE_SEED_CODER     = 35,
     LLAMA_VOCAB_PRE_TYPE_HUNYUAN        = 36,
+    LLAMA_VOCAB_PRE_TYPE_KIMI_K2        = 37,
 };
 
 struct LLM_KV;
@@ -100,6 +101,7 @@ struct llama_vocab {
     llama_token token_sep() const;
     llama_token token_nl () const;
     llama_token token_pad() const;
+    llama_token token_mask() const;
 
     llama_token token_prefix() const;
     llama_token token_middle() const;
index f73b1ab65fe6fcad6ea7afa37ab443246fa10c3e..6f454a508a06c80bb92fab68f97f06e6cb95ccd5 100644 (file)
@@ -71,12 +71,13 @@ extern "C" {
     typedef int32_t llama_seq_id;
 
     enum llama_vocab_type {
-        LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
-        LLAMA_VOCAB_TYPE_SPM  = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
-        LLAMA_VOCAB_TYPE_BPE  = 2, // GPT-2 tokenizer based on byte-level BPE
-        LLAMA_VOCAB_TYPE_WPM  = 3, // BERT tokenizer based on WordPiece
-        LLAMA_VOCAB_TYPE_UGM  = 4, // T5 tokenizer based on Unigram
-        LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
+        LLAMA_VOCAB_TYPE_NONE   = 0, // For models without vocab
+        LLAMA_VOCAB_TYPE_SPM    = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
+        LLAMA_VOCAB_TYPE_BPE    = 2, // GPT-2 tokenizer based on byte-level BPE
+        LLAMA_VOCAB_TYPE_WPM    = 3, // BERT tokenizer based on WordPiece
+        LLAMA_VOCAB_TYPE_UGM    = 4, // T5 tokenizer based on Unigram
+        LLAMA_VOCAB_TYPE_RWKV   = 5, // RWKV tokenizer based on greedy tokenization
+        LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming
     };
 
     enum llama_rope_type {
@@ -334,6 +335,9 @@ extern "C" {
         bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
                           // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
                           //       ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
+        bool kv_unified;  // use a unified buffer across the input sequences when computing the attention
+                          // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
+                          // ref: https://github.com/ggml-org/llama.cpp/pull/14363
     };
 
     // model quantization parameters
@@ -724,7 +728,7 @@ extern "C" {
     //   - lazily on next llama_decode()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    DEPRECATED(void llama_kv_self_seq_div(
+    DEPRECATED(LLAMA_API void llama_kv_self_seq_div(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
                        llama_pos   p0,
@@ -952,6 +956,7 @@ extern "C" {
     // in the order they have appeared in the batch.
     // Rows: number of tokens for which llama_batch.logits[i] != 0
     // Cols: n_vocab
+    // TODO: deprecate in favor of llama_get_logits_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);
 
     // Logits for the ith token. For positive indices, Equivalent to:
@@ -966,6 +971,7 @@ extern "C" {
     // in the order they have appeared in the batch.
     // shape: [n_outputs*n_embd]
     // Otherwise, returns NULL.
+    // TODO: deprecate in favor of llama_get_embeddings_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
     // Get the embeddings for the ith token. For positive indices, Equivalent to:
@@ -1004,6 +1010,7 @@ extern "C" {
     LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
     LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
     LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
+    LLAMA_API llama_token llama_vocab_mask(const struct llama_vocab * vocab); // mask
 
     LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
     LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
@@ -1389,6 +1396,7 @@ extern "C" {
 
         int32_t n_p_eval;
         int32_t n_eval;
+        int32_t n_reused; // number of times a ggml compute graph had been reused
     };
 
     struct llama_perf_sampler_data {
index 43a4581b961fe9075d41405ae6d7b074ac982b22..65f366517158288b9cbc28dc473010ae396fc85d 100644 (file)
@@ -557,6 +557,178 @@ static std::vector<size_t> unicode_regex_split_stl(const std::string & text, con
     return bpe_offsets;
 }
 
+// K2 system regex patterns (from tokenization_kimi.py):
+// [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
+static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector<size_t> & offsets) {
+    std::vector<size_t> bpe_offsets;
+    bpe_offsets.reserve(offsets.size());
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+
+    size_t start = 0;
+    for (auto offset : offsets) {
+        const size_t offset_ini = start;
+        const size_t offset_end = start + offset;
+        assert(offset_end <= cpts.size());
+        start = offset_end;
+
+        static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
+        auto _get_cpt = [&] (const size_t pos) -> uint32_t {
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
+        };
+
+        auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
+        };
+
+        size_t _prev_end = offset_ini;
+        auto _add_token = [&] (const size_t end) -> size_t {
+            assert(_prev_end <= end && end <= offset_end);
+            size_t len = end - _prev_end;
+            if (len > 0) {
+                bpe_offsets.push_back(len);
+            }
+            _prev_end = end;
+            return len;
+        };
+
+        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
+            const uint32_t cpt = _get_cpt(pos);
+            const auto flags = _get_flags(pos);
+
+            // Pattern 1: [\p{Han}]+ (Chinese characters)
+            if (unicode_cpt_is_han(cpt)) {
+                while (unicode_cpt_is_han(_get_cpt(pos))) {
+                    pos++;
+                }
+                _add_token(pos);
+                continue;
+            }
+
+            // Pattern 2 & 3: Letter words excluding Han characters with optional contractions
+            // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?:'s|'t|'re|'ve|'m|'ll|'d)?
+            // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?:'s|'t|'re|'ve|'m|'ll|'d)?
+            // Check if current char is a letter OR if current char could be a leading char and next char is a letter
+            bool is_letter_pattern = (flags.is_letter && !unicode_cpt_is_han(cpt)) ||
+                                     (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number) &&
+                                      _get_flags(pos + 1).is_letter && !unicode_cpt_is_han(_get_cpt(pos + 1)));
+
+            if (is_letter_pattern) {
+                // Handle optional leading non-letter/non-number character
+                bool has_leading_char = false;
+                if (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number)) {
+                    has_leading_char = true;
+                    pos++;
+                }
+
+                // Match letter sequence (excluding Han characters)
+                bool has_letters = false;
+                while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
+                    has_letters = true;
+                    pos++;
+                }
+
+                // Only proceed if we found letters (after potentially skipping leading char)
+                if (has_letters || (!has_leading_char && _get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos)))) {
+                    if (!has_letters) pos++; // consume the first letter if we didn't already
+
+                    // Continue consuming letters
+                    while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
+                        pos++;
+                    }
+
+                    // Check for optional contractions (?:'s|'t|'re|'ve|'m|'ll|'d)
+                    if (_get_cpt(pos) == '\'' && pos + 1 < offset_end) {
+                        uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1));
+                        if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
+                            pos += 2;
+                        } else if (pos + 2 < offset_end) {
+                            uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2));
+                            if ((cpt_next == 'r' && cpt_next_next == 'e') ||
+                                (cpt_next == 'v' && cpt_next_next == 'e') ||
+                                (cpt_next == 'l' && cpt_next_next == 'l')) {
+                                pos += 3;
+                            }
+                        }
+                    }
+
+                    _add_token(pos);
+                    continue;
+                } else if (has_leading_char) {
+                    // We consumed a leading char but found no letters, backtrack
+                    pos--;
+                }
+            }
+
+            // Pattern 4: \p{N}{1,3} (numbers 1-3 digits)
+            if (flags.is_number) {
+                size_t ini = pos;
+                while (_get_flags(pos).is_number) {
+                    if (++pos - ini >= 3) {
+                        _add_token(pos);
+                        ini = pos;
+                    }
+                }
+                _add_token(pos);
+                continue;
+            }
+
+            // Pattern 5:  ?[^\s\p{L}\p{N}]+[\r\n]* (optional space + non-word chars + optional newlines)
+            auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags);
+            if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
+                pos += (cpt == ' ');
+                while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
+                    flags2 = _get_flags(++pos);
+                }
+                // Match optional [\r\n]*
+                uint32_t cpt2 = _get_cpt(pos);
+                while (cpt2 == '\r' || cpt2 == '\n') {
+                    cpt2 = _get_cpt(++pos);
+                }
+                _add_token(pos);
+                continue;
+            }
+
+            // Count whitespace characters
+            size_t num_whitespaces = 0;
+            size_t last_end_r_or_n = 0;
+            while (_get_flags(pos + num_whitespaces).is_whitespace) {
+                uint32_t cpt2 = _get_cpt(pos + num_whitespaces);
+                if (cpt2 == '\r' || cpt2 == '\n') {
+                    last_end_r_or_n = pos + num_whitespaces + 1;
+                }
+                num_whitespaces++;
+            }
+
+            // Pattern 6: \s*[\r\n]+ (whitespace with newlines)
+            if (last_end_r_or_n > 0) {
+                pos = last_end_r_or_n;
+                _add_token(pos);
+                continue;
+            }
+
+            // Pattern 7: \s+(?!\S) (trailing whitespace)
+            if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) {
+                pos += num_whitespaces - 1;
+                _add_token(pos);
+                continue;
+            }
+
+            // Pattern 8: \s+ (general whitespace)
+            if (num_whitespaces > 0) {
+                pos += num_whitespaces;
+                _add_token(pos);
+                continue;
+            }
+
+            // No matches - consume single character
+            _add_token(++pos);
+        }
+    }
+
+    return bpe_offsets;
+}
+
 static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
     std::vector<size_t> bpe_offsets;
 
@@ -567,6 +739,9 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
             regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
 
         bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
+    } else if (regex_expr == "\\p{Han}+") {
+        // K2's first pattern - handle all K2 patterns together
+        bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
     }
 
     return bpe_offsets;
@@ -672,6 +847,38 @@ uint32_t unicode_tolower(uint32_t cpt) {
     return cpt;  // Return the original code point if no lowercase mapping is found
 }
 
+bool unicode_cpt_is_han(uint32_t cpt) {
+    // Han character ranges (Chinese/CJK characters)
+    // CJK Unified Ideographs (most common)
+    if (cpt >= 0x4E00 && cpt <= 0x9FFF) return true;
+
+    // CJK Extension A
+    if (cpt >= 0x3400 && cpt <= 0x4DBF) return true;
+
+    // CJK Extension B
+    if (cpt >= 0x20000 && cpt <= 0x2A6DF) return true;
+
+    // CJK Extension C
+    if (cpt >= 0x2A700 && cpt <= 0x2B73F) return true;
+
+    // CJK Extension D
+    if (cpt >= 0x2B740 && cpt <= 0x2B81F) return true;
+
+    // CJK Extension E
+    if (cpt >= 0x2B820 && cpt <= 0x2CEAF) return true;
+
+    // CJK Extension F
+    if (cpt >= 0x2CEB0 && cpt <= 0x2EBEF) return true;
+
+    // CJK Compatibility Ideographs
+    if (cpt >= 0xF900 && cpt <= 0xFAFF) return true;
+
+    // CJK Compatibility Ideographs Supplement
+    if (cpt >= 0x2F800 && cpt <= 0x2FA1F) return true;
+
+    return false;
+}
+
 std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
     // unicode categories
     static const std::map<std::string, int> k_ucat_enum = {
index c27098df7d4bec46cd06b1542d3f99c393755527..0a5fa2a78ceff3d98031110d45d495b17f56626e 100644 (file)
@@ -63,4 +63,6 @@ uint8_t     unicode_utf8_to_byte(const std::string & utf8);
 
 uint32_t unicode_tolower(uint32_t cpt);
 
+bool unicode_cpt_is_han(uint32_t cpt);
+
 std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);